diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..d5d923ad65cf1e58295af06451ed02ad14331c95 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/voice_05.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_07.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_08.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_09.wav filter=lfs diff=lfs merge=lfs -text +examples/emo_sad.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_02.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_06.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_10.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_11.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_12.wav filter=lfs diff=lfs merge=lfs -text +examples/emo_hate.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_01.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_03.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_04.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9f20088d24601f967d74587f5a0777d93c81d0c4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +venv/ +__pycache__ +*.egg-info +*.DS_Store +.idea/ +.vscode/ +checkpoints/*.pth +checkpoints/*.vocab +checkpoints/*.model +checkpoints/.cache +outputs/ +build/ +*.py[cod] +*.egg-info/ +.venv +checkpoints/* \ No newline at end of file diff --git a/INDEX_MODEL_LICENSE_EN.txt b/INDEX_MODEL_LICENSE_EN.txt new file mode 100644 index 0000000000000000000000000000000000000000..ddc3bcd0f6a3d4d2a9e1475c1fbe78e900630f0c --- /dev/null +++ b/INDEX_MODEL_LICENSE_EN.txt @@ -0,0 +1,58 @@ +bilibili Model Use License Agreement + +By clicking “I agree” to this bilibili Model Use License Agreement (“this Agreement”) , or by otherwise using any portion or element of the Model or any Derivative Work, you will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately. If you do not agree to this Agreement, you must immediately cease all use and permanently delete the Model and any Derivative Works. + +1. Definitions +1.1 “This Agreement”: means the bilibili Model Use License Agreement, including all of its terms and conditions. +1.2 “We”, “us”, or “our”: means bilibili , the original right-holder of the Model. +1.3 “You”: means any natural person or legal entity exercising rights granted by this Agreement and/or using the Model for any purpose and in any field of use. +1.4 “Model”: means the artificial-intelligence model named “bilibili indextts2”, including but not limited to model weights and final code, in each case only to the extent that such components are published by us at https://github.com/index-tts/index-tts. +1.5 “Derivative Work”: means any derivative of the Model, including without limitation: + (i) any modification of the Model, model outputs, or their derivatives; + (ii) any work based on the Model, model outputs, or their derivatives; + (iii) any other machine learning model which is created by re-training, fine-tuning, quantizing, LoRA, parameter-efficient fine-tuning, or any other method involving incremental weights or merged checkpoints, in each case based on the Model, model outputs, or their derivatives. +1.6 “Use”: means downloading, copying, training, modifying, creating Derivative Works, distributing, publishing, running, fine-tuning, publicly displaying, communicating to the public, or otherwise exploiting the Model or any Derivative Work. + +2. Scope of License and Restrictions +2.1 Subject to the terms and conditions of this Agreement, we grant you a worldwide, non-exclusive, non-transferable, royalty-free limited license to Use the Model or any Derivative Work based on the intellectual properties or other rights owned by Us embodied in the Model or any Derivative Work. +2.2 If You intend to Use, or have already Used, the Model or any Derivative Work, and either (i) your or any of your Affiliates’ products or services had more than 100 million monthly active users in the immediately preceding calendar month, or (ii) your or any of your Affiliates’ annual revenue in the immediately preceding calendar year exceeded RMB 1 billion, You must request a separated license from us, which We may grant to You in our sole discretion. You are not authorized to exercise any of the rights under this Agreement unless and until We have expressly granted You such rights in writing. +2.3 This Agreement is an open-source license for the Model in which we possess intellectual properties and other rights. It governs your Use of the Model only and does not limit any rights that we have regarding the Model. + +3. Disclaimer and Risk Allocation +3.1 The Model and any outputs generated thereby are provided “AS IS,” without warranty of any kind, express or implied, including but not limited to warranties of merchantability, fitness for a particular purpose, non-infringement, absence of errors or omissions, continuity, accuracy, reliability, or stability. You are solely responsible for determining the appropriateness of using or redistributing the Model and assume all risks associated with exercising any rights granted under this Agreement. +3.2 You shall bear sole responsibility for any infringement, illegality, breach of contract, damages, fines, regulatory investigations, or other liabilities (including, without limitation, infringement of third-party patents, copyrights, trademarks, trade secrets, personality rights, data-protection rights, or any other rights) arising out of or related to your Use of the Model or any outputs generated thereby. We assume no joint, several, supplementary, or advance payment liability. +3.3 Under no circumstances shall we be liable to you or any third party for any direct, indirect, incidental, special, punitive, or consequential damages (including, without limitation, loss of data, business interruption, or loss of profits) arising out of or related to the Use of the Model, even if we have been advised of the possibility of such damages. +3.4 Additional Obligations for You and Downstream Recipients +a) You must ensure that any downstream recipient of the Model or any Derivative Work that you distribute complies with this Agreement, and you must impose appropriate contractual terms on such downstream recipients. If any downstream recipient breaches this Agreement, you shall be responsible for the consequences thereof. +b) You must retain all original copyright notices and a copy of this Agreement in every copy of the Model or any Derivative Work that you Use. +c) You may not Use the bilibili indextts2 or any Derivative Work to improve any AI model, except for the bilibili indextts2 itself, its Derivative Works,or non-commercial AI models. + +4. Compliance Obligations +4.1 Usage Restrictions +a) If you distribute a Derivative Work, you must clearly state in the distribution page or accompanying documentation: “Any modifications made to the original model in this Derivative Work are not endorsed, warranted, or guaranteed by the original right-holder of the original model, and the original right-holder disclaims all liability related to this Derivative Work.” +b) If your Use of the Model or any Derivative Work incorporates any third-party data or weights, you must obtain all necessary authorizations on your own and bear full responsibility for compliance. +c) You may not Use the Model or any Derivative Work for any purpose that violates the laws or regulatory requirements of the jurisdiction where the outputs and/or the Model are generated or used (including, without limitation, generating false information, discriminatory content, or content that infringes privacy). +d) If the Model or any Derivative Work is capable of generating content, you must ensure that such content does not violate the laws or regulatory requirements of the applicable jurisdiction (including, without limitation, generating false information, discriminatory content, or content that infringes privacy). +4.2 Prohibited High-Risk Use +You must ensure that the Model and any Derivative Work are not deployed, directly or indirectly, in high-risk scenarios such as medical diagnosis, autonomous driving, military applications, critical-infrastructure control, large-scale biometric surveillance, or automated decision-making (e.g., credit or employment evaluations). If you insist on such deployment, you must independently complete all compliance obligations under applicable laws and regulations (including but not limited to GDPR, CCPA, HIPAA, export-control laws, and AI-specific regulations), and we shall bear no liability for any consequences arising therefrom. +4.3 Infringement Liability +Should any third party raise claims against you with respect to any Derivative Work you develop or your Use of the Model or any Derivative Work, you shall bear full and independent responsibility for defending against and resolving such claims. If your actions cause us to incur any third-party claims, administrative penalties, or other losses, you shall indemnify us for all losses we thereby suffer, including but not limited to attorney fees, litigation costs, damages, and fines, and shall take all necessary measures to eliminate any adverse impact on us. + +5. Reserved Rights +5.1 We reserve the right to revoke the license granted to you under this Agreement in the event of your breach. Upon revocation, you must immediately cease all Use and permanently delete all copies of the Model and any Derivative Work. Sections 3 and 6 of this Agreement shall survive termination of this Agreement under this circumstance. +5.2 Nothing in this Agreement grants you any right to use our trade names, trademarks, service marks, or product names, except as reasonably and customarily required to describe the origin of the Model or any Derivative Work—such as reproducing the content of a NOTICE file under Section 3.4 of this Agreement. +5.3 If you or any of your Affiliates institutes or participates in any legal proceeding (including any cross-claim or counterclaim in a lawsuit) against us or any of our Affiliates, alleging that the Model or any output or any portion thereof infringes any intellectual property or other rights that you own or control, all licenses granted to you under this Agreement shall terminate automatically as of the date such proceeding is filed. + +6. Governing Law and Dispute Resolution +6.1 This Agreement shall be governed by and construed in accordance with the laws of the People’s Republic of China. +6.2 In the event of any dispute arising out of or in connection with this Agreement, the parties shall first attempt to resolve such dispute through friendly negotiation. If negotiation fails, the dispute shall be submitted to the Shanghai Arbitration Commission for arbitration in accordance with its then-effective arbitration rules. The arbitration award shall be final and binding on both parties. The prevailing party shall be entitled to recover reasonable costs, including notarization and investigation fees, arbitration costs, attorneys’ fees, and travel expenses. + +7. Severability +If any provision of this Agreement is held to be invalid or unenforceable, the remaining provisions shall remain in full force and effect. The invalid or unenforceable provision shall be replaced with a valid and enforceable provision that, to the maximum extent permitted by law, most closely reflects the original intent of the invalid or unenforceable provision. + +8. Version Updates +We may release new versions of the AI Model Use License Agreement. Any new version will apply only to Uses occurring after the date of its release. If you obtained the Model under an earlier version, the new version will not have retroactive effect; nevertheless, you are encouraged to adopt the new version voluntarily. + +9. Language Version +In the event of any discrepancy or conflict between the English-language version set forth above and the Chinese-language version of this bilibili Model Use License Agreement, the Chinese-language version shall prevail for all purposes and shall govern the rights and obligations of the parties. + diff --git a/INDEX_MODEL_LICENSE_ZH.txt b/INDEX_MODEL_LICENSE_ZH.txt new file mode 100644 index 0000000000000000000000000000000000000000..519ba4429a9404a686990f201e55f9a3b2bc93f6 --- /dev/null +++ b/INDEX_MODEL_LICENSE_ZH.txt @@ -0,0 +1,52 @@ +bilibili模型使用许可协议 + +若您点击同意《bilibili模型使用许可协议》(“本协议”),或使用我方模型或衍生品的任何部分或元素,即视为您已确认并接受本协议内容,本协议立即生效。若您不同意本协议,应立即停止使用并删除模型及衍生品。 + +1.定义 +1.1 本协议:指《bilibili 模型使用许可协议》,包括本协议所规定的所有条款和条件。 +1.2 我方:指bilibili即模型的原始权利人。 +1.3 您:指行使本许可协议授予的权利和/或使用“模型”的自然人或法人实体。 +1.4 模型:指名为“bilibili indextts2”的AI模型,包括模型权重、最终代码等组件,具体范围以我方在https://github.com/index-tts/index-tts发布的组件为限。 +1.5 衍生品:指模型的衍生品,包括但不限于:(i)对模型、模型输出及其衍生品的修改;(ii)基于模型、模型输出及其衍生品的创作;(iii)对模型、模型输出及其衍生品再训练、微调、量化、LoRA、参数高效微调、以任何增量权重或合并的检查点等方式创建的任何模型。 +1.6 使用:指通过下载、复制、训练、修改、创作衍生品、分发、发布、运行、微调、公开展示、传播或以其他方式利用本模型或其衍生品的行为。 + +2. 许可范围和限制 +2.1 根据本协议的条款与条件,基于对模型或其衍生品中包含的我方拥有的任何知识产权和其他权利,我方特此授予您一项全球范围、非独占、不可转让、免费的使用许可。 +2.2若您拟使用或者已使用我方模型或其衍生品,如果您或者您的关联方提供的产品或服务在前一自然月的月活跃用户数超过1亿,或者如果您或者您的关联方在上一自然年的年收入超过1亿人民币的,您必须向我方申请该模型或其衍生品的商业许可,我方可自行决定是否授予您该许可。您无权行使本协议项下的任何权利,除非我方另行明确授予您该等许可。 +2.3 本协议作为我方享有知识产权和其他权利的模型的开源许可协议,仅约束您对我方模型的使用行为,并不限制我方对该模型享有的任何权利。 + +3. 免责声明与风险约定 +3.1 模型及其任何输出均“按原样”提供,我方及其关联方不提供任何形式的明示或暗示的保证,包括但不限于适销性、特定用途适用性、不侵权、没有错误或疏漏、持续性、准确性、可靠性、稳定性的保证。您需自行负责判断使用或再分发本作品的适当性,并承担行使本许可证所授予权限相关的所有风险。 +3.2 您因使用模型或利用其输出内容而产生的任何侵权、违法、违约、赔偿、罚款、监管调查或其他法律责任(包括但不限于侵犯第三方专利、版权、商标、商业秘密、人格权、数据保护权等),均由您独自承担。我方不承担任何连带责任、补充责任或垫付责任。 +3.3 在任何情况下,我方对因使用本模型而产生的任何直接、间接、附带、特殊、惩罚性或后果性损失(包括但不限于数据丢失、业务中断、利润损失等)不承担责任,即使我方已被告知该等损失的可能性。 +3.4 对您和下游用户的其他约束 +a)您应确保下游用户在使用您发布的本模型或您基于本模型开发的衍生品时,同样遵守本协议的相关规定,并通过合适的协议或条款对下游用户进行约束。若下游用户违反本协议规定,您需承担相应责任。 +b)您需在您使用的本模型或您基于本模型开发的衍生品的所有副本中保留原始版权声明及本使用许可协议。 +c)您不得使用bilibili indextts2或其衍生品来改进任何AI模型(bilibili indextts2或其衍生品、非商业用途的AI模型除外)。 + +4. 合规义务 +4.1使用限制 +a) 若您发布模型的衍生品,必须在发布页面或附随文档中清晰声明“该衍生品对原模型所作的任何改动与原模型原始权利人无关,原始权利人对该衍生品不背书、不担保、不承担责任”。 +b) 若您使用模型或模型衍生品的过程中引入任何第三方数据或权重,您须自行取得合法授权并承担全部合规责任。 +c) 不得将模型及模型衍生品用于违反输出地/使用地法律或监管要求的用途(包括但不限于生成虚假信息、歧视性内容、侵犯隐私等)。 +d) 若模型或模型衍生品具备生成内容功能,您须确保其输出内容不违反输出地/使用地法律或监管要求的用途(包括但不限于生成虚假信息、歧视性内容、侵犯隐私等)。 +4.2 禁止高风险场景 +您须自行确保不在医疗诊断、自动驾驶、军事、关键基础设施控制、大规模生物识别监控、自动化决策(如信贷、就业评估)等高风险场景直接部署本模型及其衍生品。若您坚持部署,应自行完成符合适用法规(包括 GDPR、CCPA、HIPAA、出口管制、AI 特定法规等)的全部合规要求,我方对因此产生的任何后果概不负责。 +4.3 侵权责任 +如第三方就您开发的模型衍生品或您使用模型或其衍生品等行为主张权利,您应独立承担全部责任。若因您的行为导致我方遭受任何第三方索赔、行政处罚或其他损失,您应负责赔偿我方因此遭受的全部损失,包括但不限于律师费、诉讼费、赔偿金、罚款等,并采取一切必要措施消除对我方的负面影响。 + +5. 保留权利 +5.1我方保留在您违反协议的情况下撤销本协议对您授权之权利。协议撤销后,您必须立即删除并停止使用材料。在本协议终止后,本协议第3条、第6条仍然有效。 +5.2 本许可证不授予使用我方的商号、商标、服务标记或产品名称的权限,除非在合理且惯例性地描述模型或衍生品的来源,例如本许可证3.4的规定,以及复制 NOTICE 文件内容时需要使用。 +5.3 若您或您的关联方对我方或我方任何关联实体提起诉讼或其他程序(包括诉讼中的交叉索赔或反诉),主张模型或其任何输出结果或其任何部分侵犯了您拥有或可许可的知识产权或其他权利,则本协议授予您的所有许可自该诉讼或程序提起之日起终止。 + +6. 法律适用与争议解决 +6.1 本协议适用中华人民共和国法律法规。 +6.2 在本协议履行中,若发生争议,双方应本着友好协商的原则解决问题;如协商不成,双方均应将争议提交至上海仲裁委员会根据其仲裁规则进行仲裁,仲裁是一裁终局的,对双方均有约束力。由仲裁败诉方承担本次仲裁产生的公证调查费、仲裁费、律师费、差旅费等实际产生费用。 + +7. 可分割性 +若本协议任何条款被认定为无效或不可执行,不影响其余条款之效力;无效部分应在法律允许的最大范围内按最接近原意的有效条款替代。 + +8. 协议版本更新 +我方可发布新版 AI模型使用许可协议。新版仅适用于发布后新产生的使用行为,若您已按旧版获取模型,新版协议并无溯及力,但鼓励您主动更新。 + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..ad8e99e4a4bf55dc04fed894127f7907aa44d209 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +global-exclude *~ *.py[cod] +include *.cu *.cpp +include *.h *.hpp diff --git a/README.md b/README.md index 7ceda36b56af1c1a6098c832ea52bf7b75aa0c54..de0074adbfad630ecded29113d9b9416d666f696 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,13 @@ colorFrom: yellow colorTo: gray sdk: gradio sdk_version: 5.34.1 -app_file: app.py +app_file: webui.py pinned: false -license: apache-2.0 +license: GPL-3.0 +preload_from_hub: + - IndexTeam/IndexTTS-2 + - amphion/MaskGCT + - funasr/campplus --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/archive/README_INDEXTTS_1_5.md b/archive/README_INDEXTTS_1_5.md new file mode 100644 index 0000000000000000000000000000000000000000..16c52c0400d6afea9d2545d3b141b3935909bb1b --- /dev/null +++ b/archive/README_INDEXTTS_1_5.md @@ -0,0 +1,247 @@ + +
+ +
+ + +

IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System

+ +

+ + +## 👉🏻 IndexTTS 👈🏻 + +[[HuggingFace Demo]](https://huggingface.co/spaces/IndexTeam/IndexTTS) [[ModelScope Demo]](https://modelscope.cn/studios/IndexTeam/IndexTTS-Demo) \ +[[Paper]](https://arxiv.org/abs/2502.05512) [[Demos]](https://index-tts.github.io) + +**IndexTTS** is a GPT-style text-to-speech (TTS) model mainly based on XTTS and Tortoise. It is capable of correcting the pronunciation of Chinese characters using pinyin and controlling pauses at any position through punctuation marks. We enhanced multiple modules of the system, including the improvement of speaker condition feature representation, and the integration of BigVGAN2 to optimize audio quality. Trained on tens of thousands of hours of data, our system achieves state-of-the-art performance, outperforming current popular TTS systems such as XTTS, CosyVoice2, Fish-Speech, and F5-TTS. + +Experience **IndexTTS**: Please contact xuanwu@bilibili.com for more detailed information. +### Contact +QQ群(二群):1048202584 \ +Discord:https://discord.gg/uT32E7KDmy \ +简历:indexspeech@bilibili.com \ +欢迎大家来交流讨论! +## 📣 Updates + +- `2025/05/14` 🔥🔥 We release the **IndexTTS-1.5**, Significantly improve the model's stability and its performance in the English language. +- `2025/03/25` 🔥 We release IndexTTS-1.0 model parameters and inference code. +- `2025/02/12` 🔥 We submitted our paper on arXiv, and released our demos and test sets. + +## 🖥️ Method + +The overview of IndexTTS is shown as follows. + + + + + + +The main improvements and contributions are summarized as follows: + - In Chinese scenarios, we have introduced a character-pinyin hybrid modeling approach. This allows for quick correction of mispronounced characters. + - **IndexTTS** incorporate a conformer conditioning encoder and a BigVGAN2-based speechcode decoder. This improves training stability, voice timbre similarity, and sound quality. + - We release all test sets here, including those for polysyllabic words, subjective and objective test sets. + + + +## Model Download +| 🤗**HuggingFace** | **ModelScope** | +|----------------------------------------------------------|----------------------------------------------------------| +| [IndexTTS](https://huggingface.co/IndexTeam/Index-TTS) | [IndexTTS](https://modelscope.cn/models/IndexTeam/Index-TTS) | +| [😁IndexTTS-1.5](https://huggingface.co/IndexTeam/IndexTTS-1.5) | [IndexTTS-1.5](https://modelscope.cn/models/IndexTeam/IndexTTS-1.5) | + + +## 📑 Evaluation + +**Word Error Rate (WER) Results for IndexTTS and Baseline Models on the** [**seed-test**](https://github.com/BytedanceSpeech/seed-tts-eval) + +| **WER** | **test_zh** | **test_en** | **test_hard** | +|:----------------------:|:-----------:|:-----------:|:-------------:| +| **Human** | 1.26 | 2.14 | - | +| **SeedTTS** | 1.002 | 1.945 | **6.243** | +| **CosyVoice 2** | 1.45 | 2.57 | 6.83 | +| **F5TTS** | 1.56 | 1.83 | 8.67 | +| **FireRedTTS** | 1.51 | 3.82 | 17.45 | +| **MaskGCT** | 2.27 | 2.62 | 10.27 | +| **Spark-TTS** | 1.2 | 1.98 | - | +| **MegaTTS 3** | 1.36 | 1.82 | - | +| **IndexTTS** | 0.937 | 1.936 | 6.831 | +| **IndexTTS-1.5** | **0.821** | **1.606** | 6.565 | + + +**Word Error Rate (WER) Results for IndexTTS and Baseline Models on the other opensource test** + + +| **Model** | **aishell1_test** | **commonvoice_20_test_zh** | **commonvoice_20_test_en** | **librispeech_test_clean** | **avg** | +|:---------------:|:-----------------:|:--------------------------:|:--------------------------:|:--------------------------:|:--------:| +| **Human** | 2.0 | 9.5 | 10.0 | 2.4 | 5.1 | +| **CosyVoice 2** | 1.8 | 9.1 | 7.3 | 4.9 | 5.9 | +| **F5TTS** | 3.9 | 11.7 | 5.4 | 7.8 | 8.2 | +| **Fishspeech** | 2.4 | 11.4 | 8.8 | 8.0 | 8.3 | +| **FireRedTTS** | 2.2 | 11.0 | 16.3 | 5.7 | 7.7 | +| **XTTS** | 3.0 | 11.4 | 7.1 | 3.5 | 6.0 | +| **IndexTTS** | 1.3 | 7.0 | 5.3 | 2.1 | 3.7 | +| **IndexTTS-1.5** | **1.2** | **6.8** | **3.9** | **1.7** | **3.1** | + + +**Speaker Similarity (SS) Results for IndexTTS and Baseline Models** + +| **Model** | **aishell1_test** | **commonvoice_20_test_zh** | **commonvoice_20_test_en** | **librispeech_test_clean** | **avg** | +|:---------------:|:-----------------:|:--------------------------:|:--------------------------:|:--------------------------:|:---------:| +| **Human** | 0.846 | 0.809 | 0.820 | 0.858 | 0.836 | +| **CosyVoice 2** | **0.796** | 0.743 | 0.742 | **0.837** | **0.788** | +| **F5TTS** | 0.743 | **0.747** | 0.746 | 0.828 | 0.779 | +| **Fishspeech** | 0.488 | 0.552 | 0.622 | 0.701 | 0.612 | +| **FireRedTTS** | 0.579 | 0.593 | 0.587 | 0.698 | 0.631 | +| **XTTS** | 0.573 | 0.586 | 0.648 | 0.761 | 0.663 | +| **IndexTTS** | 0.744 | 0.742 | **0.758** | 0.823 | 0.776 | +| **IndexTTS-1.5** | 0.741 | 0.722 | 0.753 | 0.819 | 0.771 | + + + +**MOS Scores for Zero-Shot Cloned Voice** + +| **Model** | **Prosody** | **Timbre** | **Quality** | **AVG** | +|-----------------|:-----------:|:----------:|:-----------:|:---------:| +| **CosyVoice 2** | 3.67 | 4.05 | 3.73 | 3.81 | +| **F5TTS** | 3.56 | 3.88 | 3.56 | 3.66 | +| **Fishspeech** | 3.40 | 3.63 | 3.69 | 3.57 | +| **FireRedTTS** | 3.79 | 3.72 | 3.60 | 3.70 | +| **XTTS** | 3.23 | 2.99 | 3.10 | 3.11 | +| **IndexTTS** | **3.79** | **4.20** | **4.05** | **4.01** | + + +## Usage Instructions +### Environment Setup +1. Download this repository: +```bash +git clone https://github.com/index-tts/index-tts.git +``` +2. Install dependencies: + +Create a new conda environment and install dependencies: + +```bash +conda create -n index-tts python=3.10 +conda activate index-tts +apt-get install ffmpeg +# or use conda to install ffmpeg +conda install -c conda-forge ffmpeg +``` + +Install [PyTorch](https://pytorch.org/get-started/locally/), e.g.: +```bash +pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118 +``` + +> [!NOTE] +> If you are using Windows you may encounter [an error](https://github.com/index-tts/index-tts/issues/61) when installing `pynini`: +`ERROR: Failed building wheel for pynini` +> In this case, please install `pynini` via `conda`: +> ```bash +> # after conda activate index-tts +> conda install -c conda-forge pynini==2.1.6 +> pip install WeTextProcessing --no-deps +> ``` + +Install `IndexTTS` as a package: +```bash +cd index-tts +pip install -e . +``` + +3. Download models: + +Download by `huggingface-cli`: + +```bash +huggingface-cli download IndexTeam/IndexTTS-1.5 \ + config.yaml bigvgan_discriminator.pth bigvgan_generator.pth bpe.model dvae.pth gpt.pth unigram_12000.vocab \ + --local-dir checkpoints +``` + +Recommended for China users. 如果下载速度慢,可以使用镜像: +```bash +export HF_ENDPOINT="https://hf-mirror.com" +``` + +Or by `wget`: + +```bash +wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bigvgan_discriminator.pth -P checkpoints +wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bigvgan_generator.pth -P checkpoints +wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bpe.model -P checkpoints +wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/dvae.pth -P checkpoints +wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/gpt.pth -P checkpoints +wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/unigram_12000.vocab -P checkpoints +wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/config.yaml -P checkpoints +``` + +> [!NOTE] +> If you prefer to use the `IndexTTS-1.0` model, please replace `IndexTeam/IndexTTS-1.5` with `IndexTeam/IndexTTS` in the above commands. + + +4. Run test script: + + +```bash +# Please put your prompt audio in 'test_data' and rename it to 'input.wav' +python indextts/infer.py +``` + +5. Use as command line tool: + +```bash +# Make sure pytorch has been installed before running this command +indextts "大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!" \ + --voice reference_voice.wav \ + --model_dir checkpoints \ + --config checkpoints/config.yaml \ + --output output.wav +``` + +Use `--help` to see more options. +```bash +indextts --help +``` + +#### Web Demo +```bash +pip install -e ".[webui]" --no-build-isolation +python webui.py + +# use another model version: +python webui.py --model_dir IndexTTS-1.5 +``` + +Open your browser and visit `http://127.0.0.1:7860` to see the demo. + + +#### Sample Code +```python +from indextts.infer import IndexTTS +tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml") +voice="reference_voice.wav" +text="大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!比如说,现在正在说话的其实是B站为我现场复刻的数字分身,简直就是平行宇宙的另一个我了。如果大家也想体验更多深入的AIGC功能,可以访问 bilibili studio,相信我,你们也会吃惊的。" +tts.infer(voice, text, output_path) +``` + +## Acknowledge +1. [tortoise-tts](https://github.com/neonbjb/tortoise-tts) +2. [XTTSv2](https://github.com/coqui-ai/TTS) +3. [BigVGAN](https://github.com/NVIDIA/BigVGAN) +4. [wenet](https://github.com/wenet-e2e/wenet/tree/main) +5. [icefall](https://github.com/k2-fsa/icefall) + +## 📚 Citation + +🌟 If you find our work helpful, please leave us a star and cite our paper. + +``` +@article{deng2025indextts, + title={IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System}, + author={Wei Deng, Siyi Zhou, Jingchen Shu, Jinchao Wang, Lu Wang}, + journal={arXiv preprint arXiv:2502.05512}, + year={2025} +} +``` diff --git a/examples/cases.jsonl b/examples/cases.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..d688b8d1dc8a10408d8d911fde418d80f8b1f950 --- /dev/null +++ b/examples/cases.jsonl @@ -0,0 +1,12 @@ +{"prompt_audio":"voice_01.wav","text":"Translate for me,what is a surprise!","emo_mode":0} +{"prompt_audio":"voice_02.wav","text":"The palace is strict, no false rumors, Lady Qi!","emo_mode":0} +{"prompt_audio":"voice_03.wav","text":"这个呀,就是我们精心制作准备的纪念品,大家可以看到这个色泽和这个材质啊,哎呀多么的光彩照人。","emo_mode":0} +{"prompt_audio":"voice_04.wav","text":"你就需要我这种专业人士的帮助,就像手无缚鸡之力的人进入雪山狩猎,一定需要最老练的猎人指导。","emo_mode":0} +{"prompt_audio":"voice_05.wav","text":"在真正的日本剑道中,格斗过程极其短暂,常常短至半秒,最长也不超过两秒,利剑相击的转瞬间,已有一方倒在血泊中。但在这电光石火的对决之前,双方都要以一个石雕般凝固的姿势站定,长时间的逼视对方,这一过程可能长达十分钟!","emo_mode":0} +{"prompt_audio":"voice_06.wav","text":"今天呢,咱们开一部新书,叫《赛博朋克二零七七》。这词儿我听着都新鲜。这赛博朋克啊,简单理解就是“高科技,低生活”。这一听,我就明白了,于老师就爱用那高科技的东西,手机都得拿脚纹开,大冬天为了解锁脱得一丝不挂,冻得跟王八蛋似的。","emo_mode":0} +{"prompt_audio":"voice_07.wav","emo_audio":"emo_sad.wav","emo_weight": 0.9, "emo_mode":1,"text":"酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。"} +{"prompt_audio":"voice_08.wav","emo_audio":"emo_hate.wav","emo_weight": 0.8, "emo_mode":1,"text":"你看看你,对我还有没有一点父子之间的信任了。"} +{"prompt_audio":"voice_09.wav","emo_vec_3":0.55,"emo_mode":2,"text":"对不起嘛!我的记性真的不太好,但是和你在一起的事情,我都会努力记住的~"} +{"prompt_audio":"voice_10.wav","emo_vec_7":0.45,"emo_mode":2,"text":"哇塞!这个爆率也太高了!欧皇附体了!"} +{"prompt_audio":"voice_11.wav","emo_mode":3,"emo_text":"极度悲伤","text":"这些年的时光终究是错付了... "} +{"prompt_audio":"voice_12.wav","emo_mode":3,"emo_text":"You scared me to death! What are you, a ghost?","text":"快躲起来!是他要来了!他要来抓我们了!"} \ No newline at end of file diff --git a/examples/emo_hate.wav b/examples/emo_hate.wav new file mode 100644 index 0000000000000000000000000000000000000000..5cfd83ecc32bf809f30a3939d96af120a72824d1 --- /dev/null +++ b/examples/emo_hate.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:89e6e7eee1a28303776e9cf43971e9505529bd0e669f5fcf47f4d1370f9187c4 +size 145368 diff --git a/examples/emo_sad.wav b/examples/emo_sad.wav new file mode 100644 index 0000000000000000000000000000000000000000..be5ce105030b4d4e842318be0e2f8d17b45637e9 --- /dev/null +++ b/examples/emo_sad.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7d3e5bf2b7bca6458f9e6d7a5ce073c41eb4418895e7df2f994e5a0c96c064a +size 842016 diff --git a/examples/voice_01.wav b/examples/voice_01.wav new file mode 100644 index 0000000000000000000000000000000000000000..65c02c7b7bd67750102005e16d390c55bd945268 --- /dev/null +++ b/examples/voice_01.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e33e6ee0107a1dd58e1d66dd90c13df3d55a8683047cc3d7ea206dad84ed3fc8 +size 478050 diff --git a/examples/voice_02.wav b/examples/voice_02.wav new file mode 100644 index 0000000000000000000000000000000000000000..4edcdb69d88d07c9209a7b1a5e281f3be77f941d --- /dev/null +++ b/examples/voice_02.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fe2dd1dbd54ef85a073fbc4c8fc0198f8d4523cc3320a600de0e347a3d8b491 +size 574074 diff --git a/examples/voice_03.wav b/examples/voice_03.wav new file mode 100644 index 0000000000000000000000000000000000000000..ead51e2db65a2251772c40f69c58da25b0bf9c52 --- /dev/null +++ b/examples/voice_03.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50e8b632efd794418919e2d33c8c2aab9189a57f4d21ef55020413be9f2b292a +size 616814 diff --git a/examples/voice_04.wav b/examples/voice_04.wav new file mode 100644 index 0000000000000000000000000000000000000000..7d74fec94b191109b61b78de7461f934cec66b5f --- /dev/null +++ b/examples/voice_04.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a3d2536245f45fd5e1eef046dd768ae7b72a0dba3ec3f370f145862fe64b3b2 +size 681084 diff --git a/examples/voice_05.wav b/examples/voice_05.wav new file mode 100644 index 0000000000000000000000000000000000000000..df4ae60230c8e248f618ce710eda6fe8b34fd0d2 --- /dev/null +++ b/examples/voice_05.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eefb7f4a29a8b36f08d5cc1014ea947dbe9f7bef348f07c40263058e604a98eb +size 1482796 diff --git a/examples/voice_06.wav b/examples/voice_06.wav new file mode 100644 index 0000000000000000000000000000000000000000..a28efcbd51fbff390d00c9fe0594ca72a184955a --- /dev/null +++ b/examples/voice_06.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d85800fe261d106c3274fa792cbb952458c4b0b2e1b908340a8cd0d63c73a30 +size 299052 diff --git a/examples/voice_07.wav b/examples/voice_07.wav new file mode 100644 index 0000000000000000000000000000000000000000..9927f437bac3cfe7ce4ae5f9cc8d864f8b812876 --- /dev/null +++ b/examples/voice_07.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bcb10f84e63c3fdbfe99ac4184ca403b46a6d20b50540732713d48c4c95375ce +size 591894 diff --git a/examples/voice_08.wav b/examples/voice_08.wav new file mode 100644 index 0000000000000000000000000000000000000000..9d72b554db7d69fe0041a7c9557a4f8a27762873 --- /dev/null +++ b/examples/voice_08.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e2c5f4859999b1ada95ee801d50c3c72879147269a4ed99e385fd917dae5c6f +size 426812 diff --git a/examples/voice_09.wav b/examples/voice_09.wav new file mode 100644 index 0000000000000000000000000000000000000000..81b98acf10f0d51099f2bea2973d55d9dc6c0185 --- /dev/null +++ b/examples/voice_09.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8702467b9b3c83a16bead578e131c4388b3ef82aeff861bd336e622a9ae8a511 +size 1798188 diff --git a/examples/voice_10.wav b/examples/voice_10.wav new file mode 100644 index 0000000000000000000000000000000000000000..a312359ca6d6330580e1022c595cc71919ef1449 --- /dev/null +++ b/examples/voice_10.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39c2db8b395e4c6ea1122ec7463b5f7bd7dd7d7302f3255780e4c529a9ae9985 +size 1942242 diff --git a/examples/voice_11.wav b/examples/voice_11.wav new file mode 100644 index 0000000000000000000000000000000000000000..8a3462550965df507d753a3f6330a3891abfcf33 --- /dev/null +++ b/examples/voice_11.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:82730e38498413d4371a76e841cd91fa2f74843b79ad3b606d45ad8a7b7a736c +size 1520734 diff --git a/examples/voice_12.wav b/examples/voice_12.wav new file mode 100644 index 0000000000000000000000000000000000000000..d2a2e2901395d4d5803a61f0bcfb3e9cc9eef0d7 --- /dev/null +++ b/examples/voice_12.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d67bd4f51773677d5902409813b9bb4c1d59b8243c74fc104553b80b49edd22b +size 778626 diff --git a/indextts/.DS_Store b/indextts/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b6bfcde0ba6b1155347cd2f8876973f81f7502c8 Binary files /dev/null and b/indextts/.DS_Store differ diff --git a/indextts/BigVGAN/.DS_Store b/indextts/BigVGAN/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e497d6903a8aa2f749319cba2923bf8353454316 Binary files /dev/null and b/indextts/BigVGAN/.DS_Store differ diff --git a/indextts/BigVGAN/ECAPA_TDNN.py b/indextts/BigVGAN/ECAPA_TDNN.py new file mode 100644 index 0000000000000000000000000000000000000000..beea8ca8fd3356b4681df030d8164f6d2223ef8c --- /dev/null +++ b/indextts/BigVGAN/ECAPA_TDNN.py @@ -0,0 +1,656 @@ +"""A popular speaker recognition and diarization model. + +Authors + * Hwidong Na 2020 +""" + +import torch # noqa: F401 +import torch.nn as nn +import torch.nn.functional as F + +from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d +from indextts.BigVGAN.nnet.linear import Linear +from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d + + +def length_to_mask(length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + + Example + ------- + >>> length=torch.Tensor([1,2,3]) + >>> mask=length_to_mask(length) + >>> mask + tensor([[1., 0., 0.], + [1., 1., 0.], + [1., 1., 1.]]) + """ + assert len(length.shape) == 1 + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange( + max_len, device=length.device, dtype=length.dtype + ).expand(len(length), max_len) < length.unsqueeze(1) + + if dtype is None: + dtype = length.dtype + + if device is None: + device = length.device + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + +# Skip transpose as much as possible for efficiency +class Conv1d(_Conv1d): + """1D convolution. Skip transpose is used to improve efficiency.""" + + def __init__(self, *args, **kwargs): + super().__init__(skip_transpose=True, *args, **kwargs) + + +class BatchNorm1d(_BatchNorm1d): + """1D batch normalization. Skip transpose is used to improve efficiency.""" + + def __init__(self, *args, **kwargs): + super().__init__(skip_transpose=True, *args, **kwargs) + + +class TDNNBlock(nn.Module): + """An implementation of TDNN. + + Arguments + --------- + in_channels : int + Number of input channels. + out_channels : int + The number of output channels. + kernel_size : int + The kernel size of the TDNN blocks. + dilation : int + The dilation of the TDNN block. + activation : torch class + A class for constructing the activation layers. + groups : int + The groups size of the TDNN blocks. + + Example + ------- + >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) + >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1) + >>> out_tensor = layer(inp_tensor).transpose(1, 2) + >>> out_tensor.shape + torch.Size([8, 120, 64]) + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + activation=nn.ReLU, + groups=1, + ): + super().__init__() + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + groups=groups, + ) + self.activation = activation() + self.norm = BatchNorm1d(input_size=out_channels) + + def forward(self, x): + """Processes the input tensor x and returns an output tensor.""" + return self.norm(self.activation(self.conv(x))) + + +class Res2NetBlock(torch.nn.Module): + """An implementation of Res2NetBlock w/ dilation. + + Arguments + --------- + in_channels : int + The number of channels expected in the input. + out_channels : int + The number of output channels. + scale : int + The scale of the Res2Net block. + kernel_size: int + The kernel size of the Res2Net block. + dilation : int + The dilation of the Res2Net block. + + Example + ------- + >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) + >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3) + >>> out_tensor = layer(inp_tensor).transpose(1, 2) + >>> out_tensor.shape + torch.Size([8, 120, 64]) + """ + + def __init__( + self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1 + ): + super().__init__() + assert in_channels % scale == 0 + assert out_channels % scale == 0 + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TDNNBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, x): + """Processes the input tensor x and returns an output tensor.""" + y = [] + for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)): + if i == 0: + y_i = x_i + elif i == 1: + y_i = self.blocks[i - 1](x_i) + else: + y_i = self.blocks[i - 1](x_i + y_i) + y.append(y_i) + y = torch.cat(y, dim=1) + return y + + +class SEBlock(nn.Module): + """An implementation of squeeze-and-excitation block. + + Arguments + --------- + in_channels : int + The number of input channels. + se_channels : int + The number of output channels after squeeze. + out_channels : int + The number of output channels. + + Example + ------- + >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) + >>> se_layer = SEBlock(64, 16, 64) + >>> lengths = torch.rand((8,)) + >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2) + >>> out_tensor.shape + torch.Size([8, 120, 64]) + """ + + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = Conv1d( + in_channels=in_channels, out_channels=se_channels, kernel_size=1 + ) + self.relu = torch.nn.ReLU(inplace=True) + self.conv2 = Conv1d( + in_channels=se_channels, out_channels=out_channels, kernel_size=1 + ) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, lengths=None): + """Processes the input tensor x and returns an output tensor.""" + L = x.shape[-1] + if lengths is not None: + mask = length_to_mask(lengths * L, max_len=L, device=x.device) + mask = mask.unsqueeze(1) + total = mask.sum(dim=2, keepdim=True) + s = (x * mask).sum(dim=2, keepdim=True) / total + else: + s = x.mean(dim=2, keepdim=True) + + s = self.relu(self.conv1(s)) + s = self.sigmoid(self.conv2(s)) + + return s * x + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + + Arguments + --------- + channels: int + The number of input channels. + attention_channels: int + The number of attention channels. + global_context: bool + Whether to use global context. + + Example + ------- + >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) + >>> asp_layer = AttentiveStatisticsPooling(64) + >>> lengths = torch.rand((8,)) + >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2) + >>> out_tensor.shape + torch.Size([8, 1, 128]) + """ + + def __init__(self, channels, attention_channels=128, global_context=True): + super().__init__() + + self.eps = 1e-12 + self.global_context = global_context + if global_context: + self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1) + else: + self.tdnn = TDNNBlock(channels, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = Conv1d( + in_channels=attention_channels, out_channels=channels, kernel_size=1 + ) + + def forward(self, x, lengths=None): + """Calculates mean and std for a batch (input tensor). + + Arguments + --------- + x : torch.Tensor + Tensor of shape [N, C, L]. + lengths : torch.Tensor + The corresponding relative lengths of the inputs. + + Returns + ------- + pooled_stats : torch.Tensor + mean and std of batch + """ + L = x.shape[-1] + + def _compute_statistics(x, m, dim=2, eps=self.eps): + mean = (m * x).sum(dim) + std = torch.sqrt( + (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps) + ) + return mean, std + + if lengths is None: + lengths = torch.ones(x.shape[0], device=x.device) + + # Make binary mask of shape [N, 1, L] + mask = length_to_mask(lengths * L, max_len=L, device=x.device) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + if self.global_context: + # torch.std is unstable for backward computation + # https://github.com/pytorch/pytorch/issues/4320 + total = mask.sum(dim=2, keepdim=True).float() + mean, std = _compute_statistics(x, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, L) + std = std.unsqueeze(2).repeat(1, 1, L) + attn = torch.cat([x, mean, std], dim=1) + else: + attn = x + + # Apply layers + attn = self.conv(self.tanh(self.tdnn(attn))) + + # Filter out zero-paddings + attn = attn.masked_fill(mask == 0, float("-inf")) + + attn = F.softmax(attn, dim=2) + mean, std = _compute_statistics(x, attn) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SERes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SEBlock. + + Arguments + --------- + in_channels: int + Expected size of input channels. + out_channels: int + The number of output channels. + res2net_scale: int + The scale of the Res2Net block. + se_channels : int + The number of output channels after squeeze. + kernel_size: int + The kernel size of the TDNN blocks. + dilation: int + The dilation of the Res2Net block. + activation : torch class + A class for constructing the activation layers. + groups: int + Number of blocked connections from input channels to output channels. + + Example + ------- + >>> x = torch.rand(8, 120, 64).transpose(1, 2) + >>> conv = SERes2NetBlock(64, 64, res2net_scale=4) + >>> out = conv(x).transpose(1, 2) + >>> out.shape + torch.Size([8, 120, 64]) + """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + activation=torch.nn.ReLU, + groups=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TDNNBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, + groups=groups, + ) + self.res2net_block = Res2NetBlock( + out_channels, out_channels, res2net_scale, kernel_size, dilation + ) + self.tdnn2 = TDNNBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, + groups=groups, + ) + self.se_block = SEBlock(out_channels, se_channels, out_channels) + + self.shortcut = None + if in_channels != out_channels: + self.shortcut = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + + def forward(self, x, lengths=None): + """Processes the input tensor x and returns an output tensor.""" + residual = x + if self.shortcut: + residual = self.shortcut(x) + + x = self.tdnn1(x) + x = self.res2net_block(x) + x = self.tdnn2(x) + x = self.se_block(x, lengths) + + return x + residual + + +class ECAPA_TDNN(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143). + + Arguments + --------- + input_size : int + Expected size of the input dimension. + device : str + Device used, e.g., "cpu" or "cuda". + lin_neurons : int + Number of neurons in linear layers. + activation : torch class + A class for constructing the activation layers. + channels : list of ints + Output channels for TDNN/SERes2Net layer. + kernel_sizes : list of ints + List of kernel sizes for each layer. + dilations : list of ints + List of dilations for kernels in each layer. + attention_channels: int + The number of attention channels. + res2net_scale : int + The scale of the Res2Net block. + se_channels : int + The number of output channels after squeeze. + global_context: bool + Whether to use global context. + groups : list of ints + List of groups for kernels in each layer. + + Example + ------- + >>> input_feats = torch.rand([5, 120, 80]) + >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192) + >>> outputs = compute_embedding(input_feats) + >>> outputs.shape + torch.Size([5, 1, 192]) + """ + + def __init__( + self, + input_size, + device="cpu", + lin_neurons=192, + activation=torch.nn.ReLU, + channels=[512, 512, 512, 512, 1536], + kernel_sizes=[5, 3, 3, 3, 1], + dilations=[1, 2, 3, 4, 1], + attention_channels=128, + res2net_scale=8, + se_channels=128, + global_context=True, + groups=[1, 1, 1, 1, 1], + ): + super().__init__() + assert len(channels) == len(kernel_sizes) + assert len(channels) == len(dilations) + self.channels = channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TDNNBlock( + input_size, + channels[0], + kernel_sizes[0], + dilations[0], + activation, + groups[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(channels) - 1): + self.blocks.append( + SERes2NetBlock( + channels[i - 1], + channels[i], + res2net_scale=res2net_scale, + se_channels=se_channels, + kernel_size=kernel_sizes[i], + dilation=dilations[i], + activation=activation, + groups=groups[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TDNNBlock( + channels[-2] * (len(channels) - 2), + channels[-1], + kernel_sizes[-1], + dilations[-1], + activation, + groups=groups[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + channels[-1], + attention_channels=attention_channels, + global_context=global_context, + ) + self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2) + + # Final linear transformation + self.fc = Conv1d( + in_channels=channels[-1] * 2, + out_channels=lin_neurons, + kernel_size=1, + ) + + def forward(self, x, lengths=None): + """Returns the embedding vector. + + Arguments + --------- + x : torch.Tensor + Tensor of shape (batch, time, channel). + lengths : torch.Tensor + Corresponding relative lengths of inputs. + + Returns + ------- + x : torch.Tensor + Embedding vector. + """ + # Minimize transpose for efficiency + x = x.transpose(1, 2) + + xl = [] + for layer in self.blocks: + try: + x = layer(x, lengths=lengths) + except TypeError: + x = layer(x) + xl.append(x) + + # Multi-layer feature aggregation + x = torch.cat(xl[1:], dim=1) + x = self.mfa(x) + + # Attentive Statistical Pooling + x = self.asp(x, lengths=lengths) + x = self.asp_bn(x) + + # Final linear transformation + x = self.fc(x) + + x = x.transpose(1, 2) + return x + + +class Classifier(torch.nn.Module): + """This class implements the cosine similarity on the top of features. + + Arguments + --------- + input_size : int + Expected size of input dimension. + device : str + Device used, e.g., "cpu" or "cuda". + lin_blocks : int + Number of linear layers. + lin_neurons : int + Number of neurons in linear layers. + out_neurons : int + Number of classes. + + Example + ------- + >>> classify = Classifier(input_size=2, lin_neurons=2, out_neurons=2) + >>> outputs = torch.tensor([ [1., -1.], [-9., 1.], [0.9, 0.1], [0.1, 0.9] ]) + >>> outputs = outputs.unsqueeze(1) + >>> cos = classify(outputs) + >>> (cos < -1.0).long().sum() + tensor(0) + >>> (cos > 1.0).long().sum() + tensor(0) + """ + + def __init__( + self, + input_size, + device="cpu", + lin_blocks=0, + lin_neurons=192, + out_neurons=1211, + ): + super().__init__() + self.blocks = nn.ModuleList() + + for block_index in range(lin_blocks): + self.blocks.extend( + [ + _BatchNorm1d(input_size=input_size), + Linear(input_size=input_size, n_neurons=lin_neurons), + ] + ) + input_size = lin_neurons + + # Final Layer + self.weight = nn.Parameter( + torch.FloatTensor(out_neurons, input_size, device=device) + ) + nn.init.xavier_uniform_(self.weight) + + def forward(self, x): + """Returns the output probabilities over speakers. + + Arguments + --------- + x : torch.Tensor + Torch tensor. + + Returns + ------- + out : torch.Tensor + Output probabilities over speakers. + """ + for layer in self.blocks: + x = layer(x) + + # Need to be normalized + x = F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight)) + return x.unsqueeze(1) diff --git a/indextts/BigVGAN/__init__.py b/indextts/BigVGAN/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/BigVGAN/activations.py b/indextts/BigVGAN/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..1962c53b516b57fc4dd6b3d03c2c403565125153 --- /dev/null +++ b/indextts/BigVGAN/activations.py @@ -0,0 +1,122 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, pow, sin +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x diff --git a/indextts/BigVGAN/alias_free_activation/.DS_Store b/indextts/BigVGAN/alias_free_activation/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..2221e3e3203bfd5e85cd941eb569ea342dcf801e Binary files /dev/null and b/indextts/BigVGAN/alias_free_activation/.DS_Store differ diff --git a/indextts/BigVGAN/alias_free_activation/__init__.py b/indextts/BigVGAN/alias_free_activation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/BigVGAN/alias_free_activation/cuda/.gitignore b/indextts/BigVGAN/alias_free_activation/cuda/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..42afabfd2abebf31384ca7797186a27a4b7dbee8 --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/cuda/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/indextts/BigVGAN/alias_free_activation/cuda/__init__.py b/indextts/BigVGAN/alias_free_activation/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py b/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py new file mode 100644 index 0000000000000000000000000000000000000000..d05f1797680b8aba2729315fe49fe8db775f618e --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +# load fused CUDA kernel: this enables importing anti_alias_activation_cuda +from indextts.BigVGAN.alias_free_activation.cuda import load +from indextts.BigVGAN.alias_free_activation.torch.resample import DownSample1d, UpSample1d + +anti_alias_activation_cuda = load.load() + + +class FusedAntiAliasActivation(torch.autograd.Function): + """ + Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs. + The hyperparameters are hard-coded in the kernel to maximize speed. + NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters. + """ + + @staticmethod + def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta): + activation_results = anti_alias_activation_cuda.forward( + inputs, up_ftr, down_ftr, alpha, beta + ) + + return activation_results + + @staticmethod + def backward(ctx, output_grads): + raise NotImplementedError + return output_grads, None, None + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + fused: bool = True, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + self.fused = fused # Whether to use fused CUDA kernel or not + + def forward(self, x): + if not self.fused: + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + else: + if self.act.__class__.__name__ == "Snake": + beta = self.act.alpha.data # Snake uses same params for alpha and beta + else: + beta = ( + self.act.beta.data + ) # Snakebeta uses different params for alpha and beta + alpha = self.act.alpha.data + if ( + not self.act.alpha_logscale + ): # Exp baked into cuda kernel, cancel it out with a log + alpha = torch.log(alpha) + beta = torch.log(beta) + + x = FusedAntiAliasActivation.apply( + x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta + ) + return x diff --git a/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp b/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5651f77143bd678169eb11564a7cf7a7969a59e --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp @@ -0,0 +1,23 @@ +/* coding=utf-8 + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #include + +extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)"); +} \ No newline at end of file diff --git a/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu b/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..a36d9177d882caa3812240fe472c88538aa8d067 --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu @@ -0,0 +1,256 @@ +/* coding=utf-8 + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "type_shim.h" +#include +#include +#include +#include +#include + +namespace +{ + // Hard-coded hyperparameters + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4; + constexpr int BUFFER_SIZE = 32; + constexpr int FILTER_SIZE = 12; + constexpr int HALF_FILTER_SIZE = 6; + constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl + constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl + constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl + + template + __global__ void anti_alias_activation_forward( + output_t *dst, + const input_t *src, + const acc_t *up_ftr, + const acc_t *down_ftr, + const acc_t *alpha, + const acc_t *beta, + int batch_size, + int channels, + int seq_len) + { + // Up and downsample filters + input_t up_filter[FILTER_SIZE]; + input_t down_filter[FILTER_SIZE]; + + // Load data from global memory including extra indices reserved for replication paddings + input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0}; + input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0}; + + // Output stores downsampled output before writing to dst + output_t output[BUFFER_SIZE]; + + // blockDim/threadIdx = (128, 1, 1) + // gridDim/blockIdx = (seq_blocks, channels, batches) + int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + int local_offset = threadIdx.x * BUFFER_SIZE; + int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset; + + // intermediate have double the seq_len + int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2; + int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset; + + // Get values needed for replication padding before moving pointer + const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + input_t seq_left_most_value = right_most_pntr[0]; + input_t seq_right_most_value = right_most_pntr[seq_len - 1]; + + // Move src and dst pointers + src += block_offset + local_offset; + dst += block_offset + local_offset; + + // Alpha and beta values for snake activatons. Applies exp by default + alpha = alpha + blockIdx.y; + beta = beta + blockIdx.y; + + acc_t alpha_val = expf(alpha[0]); + acc_t beta_val = expf(beta[0]); + + #pragma unroll + for (int it = 0; it < FILTER_SIZE; it += 1) + { + up_filter[it] = up_ftr[it]; + down_filter[it] = down_ftr[it]; + } + + // Apply replication padding for upsampling, matching torch impl + #pragma unroll + for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1) + { + int element_index = seq_offset + it; // index for element + if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value; + } + if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value; + } + if ((element_index >= 0) && (element_index < seq_len)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it]; + } + } + + // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later + #pragma unroll + for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1) + { + acc_t acc = 0.0; + int element_index = intermediate_seq_offset + it; // index for intermediate + #pragma unroll + for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) + { + if ((element_index + f_idx) >= 0) + { + acc += up_filter[f_idx] * elements[it + f_idx]; + } + } + intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc; + } + + // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later + double no_div_by_zero = 0.000000001; + #pragma unroll + for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1) + { + acc_t a = sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val); + intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * a * a; + } + + // Apply replication padding before downsampling conv from intermediates + #pragma unroll + for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1) + { + intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT]; + } + #pragma unroll + for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1) + { + intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1]; + } + + // Apply downsample strided convolution (assuming stride=2) from intermediates + #pragma unroll + for (int it = 0; it < BUFFER_SIZE; it += 1) + { + acc_t acc = 0.0; + #pragma unroll + for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) + { + // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation + acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT]; + } + output[it] = acc; + } + + // Write output to dst + #pragma unroll + for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG) + { + int element_index = seq_offset + it; + if (element_index < seq_len) + { + dst[it] = output[it]; + } + } + + } + + template + void dispatch_anti_alias_activation_forward( + output_t *dst, + const input_t *src, + const acc_t *up_ftr, + const acc_t *down_ftr, + const acc_t *alpha, + const acc_t *beta, + int batch_size, + int channels, + int seq_len) + { + if (seq_len == 0) + { + return; + } + else + { + // Use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + constexpr int seq_len_per_block = 4096; + int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block; + dim3 blocks(blocks_per_seq_len, channels, batch_size); + dim3 threads(threads_per_block, 1, 1); + + anti_alias_activation_forward + <<>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len); + } + } +} + +extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta) +{ + // Input is a 3d tensor with dimensions [batches, channels, seq_len] + const int batches = input.size(0); + const int channels = input.size(1); + const int seq_len = input.size(2); + + // Output + auto act_options = input.options().requires_grad(false); + + torch::Tensor anti_alias_activation_results = + torch::empty({batches, channels, seq_len}, act_options); + + using float32 = float; + // The dtype of input is float16, bfloat16, or float32 + // The dtype of up_filter, down_filter, alpha, and beta is float32 + // printf("input scalar type: %d\n", input.scalar_type()); + // printf("up_filter scalar type: %d\n", up_filter.scalar_type()); + // printf("down_filter scalar type: %d\n", down_filter.scalar_type()); + // printf("alpha scalar type: %d\n", alpha.scalar_type()); + // printf("beta scalar type: %d\n", beta.scalar_type()); + void *input_ptr = static_cast(input.data_ptr()); + float32 *up_filter_ptr = static_cast(up_filter.data_ptr()); + float32 *down_filter_ptr = static_cast(down_filter.data_ptr()); + float32 *alpha_ptr = static_cast(alpha.data_ptr()); + float32 *beta_ptr = static_cast(beta.data_ptr()); + void *anti_alias_activation_results_ptr = static_cast(anti_alias_activation_results.data_ptr()); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch anti alias activation_forward", + dispatch_anti_alias_activation_forward( + reinterpret_cast(anti_alias_activation_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(up_filter_ptr), + reinterpret_cast(down_filter_ptr), + reinterpret_cast(alpha_ptr), + reinterpret_cast(beta_ptr), + batches, + channels, + seq_len);); + return anti_alias_activation_results; +} \ No newline at end of file diff --git a/indextts/BigVGAN/alias_free_activation/cuda/compat.h b/indextts/BigVGAN/alias_free_activation/cuda/compat.h new file mode 100644 index 0000000000000000000000000000000000000000..25818b2edf4cb0dc9130e62c7c4de8d16a01baa5 --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/cuda/compat.h @@ -0,0 +1,29 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/indextts/BigVGAN/alias_free_activation/cuda/load.py b/indextts/BigVGAN/alias_free_activation/cuda/load.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ff4d8261b59e4c0d11c67ecd7ff0aebb45f69c --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/cuda/load.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import os +import pathlib +import subprocess + +from torch.utils import cpp_extension + +""" +Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels. +Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below +""" +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +import re +import shutil +import tempfile + +# 补丁修复:sources 路径含中文字符时,生成 build.ninja 乱码导致编译失败 +# 使用临时目录来规避 ninja 编译失败(比如中文路径) +def chinese_path_compile_support(sources, buildpath): + pattern = re.compile(r'[\u4e00-\u9fff]') + if not bool(pattern.search(str(sources[0].resolve()))): + return buildpath # 检测非中文路径跳过 + # Create build directory + resolves = [ item.name for item in sources] + ninja_compile_dir = os.path.join(tempfile.gettempdir(), "BigVGAN", "cuda") + os.makedirs(ninja_compile_dir, exist_ok=True) + new_buildpath = os.path.join(ninja_compile_dir, "build") + os.makedirs(new_buildpath, exist_ok=True) + print(f"ninja_buildpath: {new_buildpath}") + # Copy files to directory + sources.clear() + current_dir = os.path.dirname(__file__) + ALLOWED_EXTENSIONS = {'.py', '.cu', '.cpp', '.h'} + for filename in os.listdir(current_dir): + item = pathlib.Path(current_dir).joinpath(filename) + tar_path = pathlib.Path(ninja_compile_dir).joinpath(item.name) + if not item.suffix.lower() in ALLOWED_EXTENSIONS:continue + pathlib.Path(shutil.copy2(item, tar_path)) + if tar_path.name in resolves:sources.append(tar_path) + return new_buildpath + + + +def load(): + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / "build" + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=[ + "-O3", + ], + extra_cuda_cflags=[ + "-O3", + "-gencode", + "arch=compute_70,code=sm_70", + "--use_fast_math", + ] + + extra_cuda_flags + + cc_flag, + verbose=True, + ) + + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + + sources = [ + srcpath / "anti_alias_activation.cpp", + srcpath / "anti_alias_activation_cuda.cu", + ] + + # 兼容方案:ninja 特殊字符路径编译支持处理(比如中文路径) + buildpath = chinese_path_compile_support(sources, buildpath) + + anti_alias_activation_cuda = _cpp_extention_load_helper( + "anti_alias_activation_cuda", sources, extra_cuda_flags + ) + + return anti_alias_activation_cuda + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h b/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h new file mode 100644 index 0000000000000000000000000000000000000000..5db7e8a397e982d4d30d16ab6060814b98b7ab83 --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h @@ -0,0 +1,92 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "compat.h" + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch (TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch (TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } diff --git a/indextts/BigVGAN/alias_free_activation/torch/__init__.py b/indextts/BigVGAN/alias_free_activation/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..117e5ac84bacddff935119bc55c5892beea37cfa --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .act import * +from .filter import * +from .resample import * diff --git a/indextts/BigVGAN/alias_free_activation/torch/act.py b/indextts/BigVGAN/alias_free_activation/torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..d46808d36aaf682f8cd7a6501d0217718b962d9d --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/torch/act.py @@ -0,0 +1,31 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn + +from .resample import DownSample1d, UpSample1d + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/indextts/BigVGAN/alias_free_activation/torch/filter.py b/indextts/BigVGAN/alias_free_activation/torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..172cfc656b81347b25c94a12610371997fade011 --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/torch/filter.py @@ -0,0 +1,102 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d( + cutoff, half_width, kernel_size +): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + """ + Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal. + """ + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ): + """ + kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible. + """ + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + # Input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out diff --git a/indextts/BigVGAN/alias_free_activation/torch/resample.py b/indextts/BigVGAN/alias_free_activation/torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..46c47705469b5d33bb8561560a440c1e36f9e79f --- /dev/null +++ b/indextts/BigVGAN/alias_free_activation/torch/resample.py @@ -0,0 +1,58 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F + +from .filter import LowPassFilter1d, kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/indextts/BigVGAN/alias_free_torch/__init__.py b/indextts/BigVGAN/alias_free_torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..117e5ac84bacddff935119bc55c5892beea37cfa --- /dev/null +++ b/indextts/BigVGAN/alias_free_torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .act import * +from .filter import * +from .resample import * diff --git a/indextts/BigVGAN/alias_free_torch/act.py b/indextts/BigVGAN/alias_free_torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..e6798bf8c64e5c91a48d4dad86d4bf58da29593c --- /dev/null +++ b/indextts/BigVGAN/alias_free_torch/act.py @@ -0,0 +1,29 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn + +from .resample import DownSample1d, UpSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/indextts/BigVGAN/alias_free_torch/filter.py b/indextts/BigVGAN/alias_free_torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..2a90bdab31738bde1bbaf2a1acc022cc177139be --- /dev/null +++ b/indextts/BigVGAN/alias_free_torch/filter.py @@ -0,0 +1,96 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/indextts/BigVGAN/alias_free_torch/resample.py b/indextts/BigVGAN/alias_free_torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf4d5466e7e1d6b0723e6dec7a1e287e6be6eb9 --- /dev/null +++ b/indextts/BigVGAN/alias_free_torch/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F + +from .filter import LowPassFilter1d, kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/indextts/BigVGAN/bigvgan.py b/indextts/BigVGAN/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..58e4ba18f6e4cdb39516c905fe72bf360d026137 --- /dev/null +++ b/indextts/BigVGAN/bigvgan.py @@ -0,0 +1,534 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import json +import os +from pathlib import Path +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, weight_norm + +import indextts.BigVGAN.activations as activations +from indextts.BigVGAN.alias_free_activation.torch.act import \ + Activation1d as TorchActivation1d +from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN +from indextts.BigVGAN.env import AttrDict +from indextts.BigVGAN.utils import get_padding, init_weights + + +def load_hparams_from_json(path) -> AttrDict: + with open(path) as f: + data = f.read() + return AttrDict(json.loads(data)) + + +class AMPBlock1(torch.nn.Module): + """ + AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer. + AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1 + + Args: + h (AttrDict): Hyperparameters. + channels (int): Number of convolution channels. + kernel_size (int): Size of the convolution kernel. Default is 3. + dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5). + activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None. + """ + + def __init__( + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, + ): + super().__init__() + + self.h = h + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), + ) + ) + for d in dilation + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ) + for _ in range(len(dilation)) + ] + ) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len( + self.convs2 + ) # Total number of conv layers + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from alias_free_activation.cuda.activation1d import \ + Activation1d as CudaActivation1d + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + # Activation functions + if activation == "snake": + self.activations = nn.ModuleList( + [ + Activation1d( + activation=activations.Snake( + channels, alpha_logscale=h.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + elif activation == "snakebeta": + self.activations = nn.ModuleList( + [ + Activation1d( + activation=activations.SnakeBeta( + channels, alpha_logscale=h.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + """ + AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer. + Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1 + + Args: + h (AttrDict): Hyperparameters. + channels (int): Number of convolution channels. + kernel_size (int): Size of the convolution kernel. Default is 3. + dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5). + activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None. + """ + + def __init__( + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, + ): + super().__init__() + + self.h = h + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), + ) + ) + for d in dilation + ] + ) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # Total number of conv layers + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from alias_free_activation.cuda.activation1d import \ + Activation1d as CudaActivation1d + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + # Activation functions + if activation == "snake": + self.activations = nn.ModuleList( + [ + Activation1d( + activation=activations.Snake( + channels, alpha_logscale=h.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + elif activation == "snakebeta": + self.activations = nn.ModuleList( + [ + Activation1d( + activation=activations.SnakeBeta( + channels, alpha_logscale=h.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +''' + PyTorchModelHubMixin, + library_name="bigvgan", + repo_url="https://github.com/NVIDIA/BigVGAN", + docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md", + pipeline_tag="audio-to-audio", + license="mit", + tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"], +''' + + +class BigVGAN( + torch.nn.Module, +): + """ + BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks). + New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks. + + Args: + h (AttrDict): Hyperparameters. + use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels. + + Note: + - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported. + - Ensure that the activation function is correctly specified in the hyperparameters (h.activation). + """ + + def __init__(self, h: AttrDict, use_cuda_kernel: bool = False): + super().__init__() + self.h = h + self.h["use_cuda_kernel"] = use_cuda_kernel + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from alias_free_activation.cuda.activation1d import \ + Activation1d as CudaActivation1d + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + self.feat_upsample = h.feat_upsample + self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer + + # Pre-conv + self.conv_pre = weight_norm( + Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3) + ) + + # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + if h.resblock == "1": + resblock_class = AMPBlock1 + elif h.resblock == "2": + resblock_class = AMPBlock2 + else: + raise ValueError( + f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}" + ) + + # Transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList( + [ + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ] + ) + ) + + # Residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append( + resblock_class(h, ch, k, d, activation=h.activation) + ) + + # Post-conv + activation_post = ( + activations.Snake(ch, alpha_logscale=h.snake_logscale) + if h.activation == "snake" + else ( + activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) + if h.activation == "snakebeta" + else None + ) + ) + if activation_post is None: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.activation_post = Activation1d(activation=activation_post) + + # Whether to use bias for the final conv_post. Default to True for backward compatibility + self.use_bias_at_final = h.get("use_bias_at_final", True) + self.conv_post = weight_norm( + Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final) + ) + + # Weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + # Final tanh activation. Defaults to True for backward compatibility + self.use_tanh_at_final = h.get("use_tanh_at_final", True) + + self.speaker_encoder = ECAPA_TDNN(h.num_mels, lin_neurons=h.speaker_embedding_dim) + self.cond_layer = nn.Conv1d(h.speaker_embedding_dim, h.upsample_initial_channel, 1) + if self.cond_in_each_up_layer: + self.conds = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + self.conds.append(nn.Conv1d(h.speaker_embedding_dim, ch, 1)) + + def forward(self, x, mel_refer, lens=None): + # Speaker reference + speaker_embedding = self.speaker_encoder(mel_refer, lens) + n_batch = x.size(0) + contrastive_loss = None + if n_batch * 2 == speaker_embedding.size(0): + spe_emb_chunk1, spe_emb_chunk2 = speaker_embedding[:n_batch, :, :], speaker_embedding[n_batch:, :, :] + contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), + self.logit_scale.exp()) + + speaker_embedding = speaker_embedding[:n_batch, :, :] + speaker_embedding = speaker_embedding.transpose(1, 2) + + # upsample feat + if self.feat_upsample: + x = torch.nn.functional.interpolate( + x.transpose(1, 2), + scale_factor=[4], + mode="linear", + ).squeeze(1) + else: + x = x.transpose(1, 2) + + # BigVGAN + # Pre-conv + x = self.conv_pre(x) + x = x + self.cond_layer(speaker_embedding) + + for i in range(self.num_upsamples): + # Upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + + if self.cond_in_each_up_layer: + x = x + self.conds[i](speaker_embedding) + + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # Post-conv + x = self.activation_post(x) + x = self.conv_post(x) + # Final tanh activation + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1] + + return x, contrastive_loss + + def remove_weight_norm(self): + try: + print("Removing weight norm...") + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + except ValueError: + print("[INFO] Model already removed weight norm. Skipping!") + pass + + # Additional methods for huggingface_hub support + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights and config.json from a Pytorch model to a local directory.""" + + model_path = save_directory / "bigvgan_generator.pt" + torch.save({"generator": self.state_dict()}, model_path) + + config_path = save_directory / "config.json" + with open(config_path, "w") as config_file: + json.dump(self.h, config_file, indent=4) + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: str, + cache_dir: str, + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", # Additional argument + strict: bool = False, # Additional argument + use_cuda_kernel: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + + # Download and load hyperparameters (h) used by BigVGAN + if os.path.isdir(model_id): + print("Loading config.json from local directory") + config_file = os.path.join(model_id, "config.json") + else: + config_file = hf_hub_download( + repo_id=model_id, + filename="config.json", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + h = load_hparams_from_json(config_file) + + # instantiate BigVGAN using h + if use_cuda_kernel: + print( + f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!" + ) + print( + f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!" + ) + print( + f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis" + ) + model = cls(h, use_cuda_kernel=use_cuda_kernel) + + # Download and load pretrained generator weight + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, "bigvgan_generator.pt") + else: + print(f"Loading weights from {model_id}") + model_file = hf_hub_download( + repo_id=model_id, + filename="bigvgan_generator.pt", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + + checkpoint_dict = torch.load(model_file, map_location=map_location) + + try: + model.load_state_dict(checkpoint_dict["generator"]) + except RuntimeError: + print( + f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!" + ) + model.remove_weight_norm() + model.load_state_dict(checkpoint_dict["generator"]) + + return model diff --git a/indextts/BigVGAN/models.py b/indextts/BigVGAN/models.py new file mode 100644 index 0000000000000000000000000000000000000000..771b89bc2c8d0ddca994ce8e68d16534cc0e1d85 --- /dev/null +++ b/indextts/BigVGAN/models.py @@ -0,0 +1,451 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +import indextts.BigVGAN.activations as activations + +from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN +from indextts.BigVGAN.utils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class AMPBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): + super(AMPBlock1, self).__init__() + self.h = h + + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + if self.h.get("use_cuda_kernel", False): + from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d + else: + from indextts.BigVGAN.alias_free_torch import Activation1d + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None): + super(AMPBlock2, self).__init__() + self.h = h + + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # total number of conv layers + if self.h.get("use_cuda_kernel", False): + from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d + else: + from indextts.BigVGAN.alias_free_torch import Activation1d + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN(torch.nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__(self, h, use_cuda_kernel=False): + """ + Args: + h (dict) + use_cuda_kernel (bool): whether to use custom cuda kernel for anti-aliased activation + """ + super(BigVGAN, self).__init__() + self.h = h + self.h["use_cuda_kernel"] = use_cuda_kernel + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + self.feat_upsample = h.feat_upsample + self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer + + # pre conv + self.conv_pre = weight_norm(Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3)) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2 + + # transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append(nn.ModuleList([ + weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2)) + ])) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(self.h, ch, k, d, activation=h.activation)) + if use_cuda_kernel: + from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d + else: + from indextts.BigVGAN.alias_free_torch import Activation1d + + # post conv + if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + self.speaker_encoder = ECAPA_TDNN(h.num_mels, lin_neurons=h.speaker_embedding_dim) + self.cond_layer = nn.Conv1d(h.speaker_embedding_dim, h.upsample_initial_channel, 1) + if self.cond_in_each_up_layer: + self.conds = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + self.conds.append(nn.Conv1d(h.speaker_embedding_dim, ch, 1)) + + # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def forward(self, x, mel_ref, lens=None): + speaker_embedding = self.speaker_encoder(mel_ref, lens) + n_batch = x.size(0) + contrastive_loss = None + if n_batch * 2 == speaker_embedding.size(0): + spe_emb_chunk1, spe_emb_chunk2 = speaker_embedding[:n_batch, :, :], speaker_embedding[n_batch:, :, :] + contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp()) + + speaker_embedding = speaker_embedding[:n_batch, :, :] + speaker_embedding = speaker_embedding.transpose(1, 2) + + # upsample feat + if self.feat_upsample: + x = torch.nn.functional.interpolate( + x.transpose(1, 2), + scale_factor=[4], + mode="linear", + ).squeeze(1) + else: + x = x.transpose(1, 2) + + ### bigVGAN ### + # pre conv + x = self.conv_pre(x) + + x = x + self.cond_layer(speaker_embedding) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + + if self.cond_in_each_up_layer: + x = x + self.conds[i](speaker_embedding) + + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x, contrastive_loss + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + def cal_clip_loss(self, image_features, text_features, logit_scale): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + labels = torch.arange(logits_per_image.shape[0], device=device, dtype=torch.long) + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + return total_loss + + def get_logits(self, image_features, text_features, logit_scale): + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + return logits_per_image, logits_per_text + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.d_mult = h.discriminator_channel_mult + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, h): + super(MultiPeriodDiscriminator, self).__init__() + self.mpd_reshapes = h.mpd_reshapes + print("mpd_reshapes: {}".format(self.mpd_reshapes)) + discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes] + self.discriminators = nn.ModuleList(discriminators) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__(self, cfg, resolution): + super().__init__() + + self.resolution = resolution + assert len(self.resolution) == 3, \ + "MRD layer requires list with len=3, got {}".format(self.resolution) + self.lrelu_slope = LRELU_SLOPE + + norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm + if hasattr(cfg, "mrd_use_spectral_norm"): + print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm)) + norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm + self.d_mult = cfg.discriminator_channel_mult + if hasattr(cfg, "mrd_channel_mult"): + print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult)) + self.d_mult = cfg.mrd_channel_mult + + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))), + ]) + self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))) + + def forward(self, x): + fmap = [] + + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x): + n_fft, hop_length, win_length = self.resolution + x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect') + x = x.squeeze(1) + x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True) + x = torch.view_as_real(x) # [B, F, TT, 2] + mag = torch.norm(x, p=2, dim=-1) # [B, F, TT] + + return mag + + +class MultiResolutionDiscriminator(nn.Module): + def __init__(self, cfg, debug=False): + super().__init__() + self.resolutions = cfg.resolutions + assert len(self.resolutions) == 3, \ + "MRD requires list of list with len=3, each element having a list with len=3. got {}".\ + format(self.resolutions) + self.discriminators = nn.ModuleList( + [DiscriminatorR(cfg, resolution) for resolution in self.resolutions] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/indextts/BigVGAN/nnet/CNN.py b/indextts/BigVGAN/nnet/CNN.py new file mode 100644 index 0000000000000000000000000000000000000000..fa79dc14e2112350d39f94b5fb45ff148f0fb7ec --- /dev/null +++ b/indextts/BigVGAN/nnet/CNN.py @@ -0,0 +1,546 @@ +"""Library implementing convolutional neural networks. + +Authors + * Mirco Ravanelli 2020 + * Jianyuan Zhong 2020 + * Cem Subakan 2021 + * Davide Borra 2021 + * Andreas Nautsch 2022 + * Sarthak Yadav 2022 +""" + +import logging +import math +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + + +class SincConv(nn.Module): + """This function implements SincConv (SincNet). + + M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with + SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158) + + Arguments + --------- + out_channels : int + It is the number of output channels. + kernel_size: int + Kernel size of the convolutional filters. + input_shape : tuple + The shape of the input. Alternatively use ``in_channels``. + in_channels : int + The number of input channels. Alternatively use ``input_shape``. + stride : int + Stride factor of the convolutional filters. When the stride factor > 1, + a decimation in time is performed. + dilation : int + Dilation factor of the convolutional filters. + padding : str + (same, valid, causal). If "valid", no padding is performed. + If "same" and stride is 1, output shape is the same as the input shape. + "causal" results in causal (dilated) convolutions. + padding_mode : str + This flag specifies the type of padding. See torch.nn documentation + for more information. + sample_rate : int + Sampling rate of the input signals. It is only used for sinc_conv. + min_low_hz : float + Lowest possible frequency (in Hz) for a filter. It is only used for + sinc_conv. + min_band_hz : float + Lowest possible value (in Hz) for a filter bandwidth. + + Example + ------- + >>> inp_tensor = torch.rand([10, 16000]) + >>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11) + >>> out_tensor = conv(inp_tensor) + >>> out_tensor.shape + torch.Size([10, 16000, 25]) + """ + + def __init__( + self, + out_channels, + kernel_size, + input_shape=None, + in_channels=None, + stride=1, + dilation=1, + padding="same", + padding_mode="reflect", + sample_rate=16000, + min_low_hz=50, + min_band_hz=50, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.padding = padding + self.padding_mode = padding_mode + self.sample_rate = sample_rate + self.min_low_hz = min_low_hz + self.min_band_hz = min_band_hz + + # input shape inference + if input_shape is None and self.in_channels is None: + raise ValueError("Must provide one of input_shape or in_channels") + + if self.in_channels is None: + self.in_channels = self._check_input_shape(input_shape) + + if self.out_channels % self.in_channels != 0: + raise ValueError( + "Number of output channels must be divisible by in_channels" + ) + + # Initialize Sinc filters + self._init_sinc_conv() + + def forward(self, x): + """Returns the output of the convolution. + + Arguments + --------- + x : torch.Tensor (batch, time, channel) + input to convolve. 2d or 4d tensors are expected. + + Returns + ------- + wx : torch.Tensor + The convolved outputs. + """ + x = x.transpose(1, -1) + self.device = x.device + + unsqueeze = x.ndim == 2 + if unsqueeze: + x = x.unsqueeze(1) + + if self.padding == "same": + x = self._manage_padding( + x, self.kernel_size, self.dilation, self.stride + ) + + elif self.padding == "causal": + num_pad = (self.kernel_size - 1) * self.dilation + x = F.pad(x, (num_pad, 0)) + + elif self.padding == "valid": + pass + + else: + raise ValueError( + "Padding must be 'same', 'valid' or 'causal'. Got %s." + % (self.padding) + ) + + sinc_filters = self._get_sinc_filters() + + wx = F.conv1d( + x, + sinc_filters, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=self.in_channels, + ) + + if unsqueeze: + wx = wx.squeeze(1) + + wx = wx.transpose(1, -1) + + return wx + + def _check_input_shape(self, shape): + """Checks the input shape and returns the number of input channels.""" + + if len(shape) == 2: + in_channels = 1 + elif len(shape) == 3: + in_channels = shape[-1] + else: + raise ValueError( + "sincconv expects 2d or 3d inputs. Got " + str(len(shape)) + ) + + # Kernel size must be odd + if self.kernel_size % 2 == 0: + raise ValueError( + "The field kernel size must be an odd number. Got %s." + % (self.kernel_size) + ) + return in_channels + + def _get_sinc_filters(self): + """This functions creates the sinc-filters to used for sinc-conv.""" + # Computing the low frequencies of the filters + low = self.min_low_hz + torch.abs(self.low_hz_) + + # Setting minimum band and minimum freq + high = torch.clamp( + low + self.min_band_hz + torch.abs(self.band_hz_), + self.min_low_hz, + self.sample_rate / 2, + ) + band = (high - low)[:, 0] + + # Passing from n_ to the corresponding f_times_t domain + self.n_ = self.n_.to(self.device) + self.window_ = self.window_.to(self.device) + f_times_t_low = torch.matmul(low, self.n_) + f_times_t_high = torch.matmul(high, self.n_) + + # Left part of the filters. + band_pass_left = ( + (torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) + / (self.n_ / 2) + ) * self.window_ + + # Central element of the filter + band_pass_center = 2 * band.view(-1, 1) + + # Right part of the filter (sinc filters are symmetric) + band_pass_right = torch.flip(band_pass_left, dims=[1]) + + # Combining left, central, and right part of the filter + band_pass = torch.cat( + [band_pass_left, band_pass_center, band_pass_right], dim=1 + ) + + # Amplitude normalization + band_pass = band_pass / (2 * band[:, None]) + + # Setting up the filter coefficients + filters = band_pass.view(self.out_channels, 1, self.kernel_size) + + return filters + + def _init_sinc_conv(self): + """Initializes the parameters of the sinc_conv layer.""" + + # Initialize filterbanks such that they are equally spaced in Mel scale + high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) + + mel = torch.linspace( + self._to_mel(self.min_low_hz), + self._to_mel(high_hz), + self.out_channels + 1, + ) + + hz = self._to_hz(mel) + + # Filter lower frequency and bands + self.low_hz_ = hz[:-1].unsqueeze(1) + self.band_hz_ = (hz[1:] - hz[:-1]).unsqueeze(1) + + # Maiking freq and bands learnable + self.low_hz_ = nn.Parameter(self.low_hz_) + self.band_hz_ = nn.Parameter(self.band_hz_) + + # Hamming window + n_lin = torch.linspace( + 0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)) + ) + self.window_ = 0.54 - 0.46 * torch.cos( + 2 * math.pi * n_lin / self.kernel_size + ) + + # Time axis (only half is needed due to symmetry) + n = (self.kernel_size - 1) / 2.0 + self.n_ = ( + 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate + ) + + def _to_mel(self, hz): + """Converts frequency in Hz to the mel scale.""" + return 2595 * np.log10(1 + hz / 700) + + def _to_hz(self, mel): + """Converts frequency in the mel scale to Hz.""" + return 700 * (10 ** (mel / 2595) - 1) + + def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): + """This function performs zero-padding on the time axis + such that their lengths is unchanged after the convolution. + + Arguments + --------- + x : torch.Tensor + Input tensor. + kernel_size : int + Size of kernel. + dilation : int + Dilation used. + stride : int + Stride. + + Returns + ------- + x : torch.Tensor + """ + + # Detecting input shape + L_in = self.in_channels + + # Time padding + padding = get_padding_elem(L_in, stride, kernel_size, dilation) + + # Applying padding + x = F.pad(x, padding, mode=self.padding_mode) + + return x + + +class Conv1d(nn.Module): + """This function implements 1d convolution. + + Arguments + --------- + out_channels : int + It is the number of output channels. + kernel_size : int + Kernel size of the convolutional filters. + input_shape : tuple + The shape of the input. Alternatively use ``in_channels``. + in_channels : int + The number of input channels. Alternatively use ``input_shape``. + stride : int + Stride factor of the convolutional filters. When the stride factor > 1, + a decimation in time is performed. + dilation : int + Dilation factor of the convolutional filters. + padding : str + (same, valid, causal). If "valid", no padding is performed. + If "same" and stride is 1, output shape is the same as the input shape. + "causal" results in causal (dilated) convolutions. + groups : int + Number of blocked connections from input channels to output channels. + bias : bool + Whether to add a bias term to convolution operation. + padding_mode : str + This flag specifies the type of padding. See torch.nn documentation + for more information. + skip_transpose : bool + If False, uses batch x time x channel convention of speechbrain. + If True, uses batch x channel x time convention. + weight_norm : bool + If True, use weight normalization, + to be removed with self.remove_weight_norm() at inference + conv_init : str + Weight initialization for the convolution network + default_padding: str or int + This sets the default padding mode that will be used by the pytorch Conv1d backend. + + Example + ------- + >>> inp_tensor = torch.rand([10, 40, 16]) + >>> cnn_1d = Conv1d( + ... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5 + ... ) + >>> out_tensor = cnn_1d(inp_tensor) + >>> out_tensor.shape + torch.Size([10, 40, 8]) + """ + + def __init__( + self, + out_channels, + kernel_size, + input_shape=None, + in_channels=None, + stride=1, + dilation=1, + padding="same", + groups=1, + bias=True, + padding_mode="reflect", + skip_transpose=False, + weight_norm=False, + conv_init=None, + default_padding=0, + ): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.padding = padding + self.padding_mode = padding_mode + self.unsqueeze = False + self.skip_transpose = skip_transpose + + if input_shape is None and in_channels is None: + raise ValueError("Must provide one of input_shape or in_channels") + + if in_channels is None: + in_channels = self._check_input_shape(input_shape) + + self.in_channels = in_channels + + self.conv = nn.Conv1d( + in_channels, + out_channels, + self.kernel_size, + stride=self.stride, + dilation=self.dilation, + padding=default_padding, + groups=groups, + bias=bias, + ) + + if conv_init == "kaiming": + nn.init.kaiming_normal_(self.conv.weight) + elif conv_init == "zero": + nn.init.zeros_(self.conv.weight) + elif conv_init == "normal": + nn.init.normal_(self.conv.weight, std=1e-6) + + if weight_norm: + self.conv = nn.utils.weight_norm(self.conv) + + def forward(self, x): + """Returns the output of the convolution. + + Arguments + --------- + x : torch.Tensor (batch, time, channel) + input to convolve. 2d or 4d tensors are expected. + + Returns + ------- + wx : torch.Tensor + The convolved outputs. + """ + if not self.skip_transpose: + x = x.transpose(1, -1) + + if self.unsqueeze: + x = x.unsqueeze(1) + + if self.padding == "same": + x = self._manage_padding( + x, self.kernel_size, self.dilation, self.stride + ) + + elif self.padding == "causal": + num_pad = (self.kernel_size - 1) * self.dilation + x = F.pad(x, (num_pad, 0)) + + elif self.padding == "valid": + pass + + else: + raise ValueError( + "Padding must be 'same', 'valid' or 'causal'. Got " + + self.padding + ) + + wx = self.conv(x) + + if self.unsqueeze: + wx = wx.squeeze(1) + + if not self.skip_transpose: + wx = wx.transpose(1, -1) + + return wx + + def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): + """This function performs zero-padding on the time axis + such that their lengths is unchanged after the convolution. + + Arguments + --------- + x : torch.Tensor + Input tensor. + kernel_size : int + Size of kernel. + dilation : int + Dilation used. + stride : int + Stride. + + Returns + ------- + x : torch.Tensor + The padded outputs. + """ + + # Detecting input shape + L_in = self.in_channels + + # Time padding + padding = get_padding_elem(L_in, stride, kernel_size, dilation) + + # Applying padding + x = F.pad(x, padding, mode=self.padding_mode) + + return x + + def _check_input_shape(self, shape): + """Checks the input shape and returns the number of input channels.""" + + if len(shape) == 2: + self.unsqueeze = True + in_channels = 1 + elif self.skip_transpose: + in_channels = shape[1] + elif len(shape) == 3: + in_channels = shape[2] + else: + raise ValueError( + "conv1d expects 2d, 3d inputs. Got " + str(len(shape)) + ) + + # Kernel size must be odd + if not self.padding == "valid" and self.kernel_size % 2 == 0: + raise ValueError( + "The field kernel size must be an odd number. Got %s." + % (self.kernel_size) + ) + + return in_channels + + def remove_weight_norm(self): + """Removes weight normalization at inference if used during training.""" + self.conv = nn.utils.remove_weight_norm(self.conv) + + +def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int): + """This function computes the number of elements to add for zero-padding. + + Arguments + --------- + L_in : int + stride: int + kernel_size : int + dilation : int + + Returns + ------- + padding : int + The size of the padding to be added + """ + if stride > 1: + padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)] + + else: + L_out = ( + math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1 + ) + padding = [ + math.floor((L_in - L_out) / 2), + math.floor((L_in - L_out) / 2), + ] + return padding + diff --git a/indextts/BigVGAN/nnet/__init__.py b/indextts/BigVGAN/nnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/BigVGAN/nnet/linear.py b/indextts/BigVGAN/nnet/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..de808bad030f2e07209a026cfbff5a67a677603b --- /dev/null +++ b/indextts/BigVGAN/nnet/linear.py @@ -0,0 +1,89 @@ +"""Library implementing linear transformation. + +Authors + * Mirco Ravanelli 2020 + * Davide Borra 2021 +""" + +import logging + +import torch +import torch.nn as nn + + +class Linear(torch.nn.Module): + """Computes a linear transformation y = wx + b. + + Arguments + --------- + n_neurons : int + It is the number of output neurons (i.e, the dimensionality of the + output). + input_shape : tuple + It is the shape of the input tensor. + input_size : int + Size of the input tensor. + bias : bool + If True, the additive bias b is adopted. + max_norm : float + weight max-norm. + combine_dims : bool + If True and the input is 4D, combine 3rd and 4th dimensions of input. + + Example + ------- + >>> inputs = torch.rand(10, 50, 40) + >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100) + >>> output = lin_t(inputs) + >>> output.shape + torch.Size([10, 50, 100]) + """ + + def __init__( + self, + n_neurons, + input_shape=None, + input_size=None, + bias=True, + max_norm=None, + combine_dims=False, + ): + super().__init__() + self.max_norm = max_norm + self.combine_dims = combine_dims + + if input_shape is None and input_size is None: + raise ValueError("Expected one of input_shape or input_size") + + if input_size is None: + input_size = input_shape[-1] + if len(input_shape) == 4 and self.combine_dims: + input_size = input_shape[2] * input_shape[3] + + # Weights are initialized following pytorch approach + self.w = nn.Linear(input_size, n_neurons, bias=bias) + + def forward(self, x): + """Returns the linear transformation of input tensor. + + Arguments + --------- + x : torch.Tensor + Input to transform linearly. + + Returns + ------- + wx : torch.Tensor + The linearly transformed outputs. + """ + if x.ndim == 4 and self.combine_dims: + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]) + + if self.max_norm is not None: + self.w.weight.data = torch.renorm( + self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm + ) + + wx = self.w(x) + + return wx diff --git a/indextts/BigVGAN/nnet/normalization.py b/indextts/BigVGAN/nnet/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..bd8d9193fc162e7df76220deb36779ffbf241fb9 --- /dev/null +++ b/indextts/BigVGAN/nnet/normalization.py @@ -0,0 +1,670 @@ +"""Library implementing normalization. + +Authors + * Mirco Ravanelli 2020 + * Guillermo Cámbara 2021 + * Sarthak Yadav 2022 +""" + +import torch +import torch.nn as nn + + +class BatchNorm1d(nn.Module): + """Applies 1d batch normalization to the input tensor. + + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + momentum : float + It is a value used for the running_mean and running_var computation. + affine : bool + When set to True, the affine parameters are learned. + track_running_stats : bool + When set to True, this module tracks the running mean and variance, + and when set to False, this module does not track such statistics. + combine_batch_time : bool + When true, it combines batch an time axis. + skip_transpose : bool + Whether to skip the transposition. + + + Example + ------- + >>> input = torch.randn(100, 10) + >>> norm = BatchNorm1d(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + torch.Size([100, 10]) + """ + + def __init__( + self, + input_shape=None, + input_size=None, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + combine_batch_time=False, + skip_transpose=False, + ): + super().__init__() + self.combine_batch_time = combine_batch_time + self.skip_transpose = skip_transpose + + if input_size is None and skip_transpose: + input_size = input_shape[1] + elif input_size is None: + input_size = input_shape[-1] + + self.norm = nn.BatchNorm1d( + input_size, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, [channels]) + input to normalize. 2d or 3d tensors are expected in input + 4d tensors can be used when combine_dims=True. + + Returns + ------- + x_n : torch.Tensor + The normalized outputs. + """ + shape_or = x.shape + if self.combine_batch_time: + if x.ndim == 3: + x = x.reshape(shape_or[0] * shape_or[1], shape_or[2]) + else: + x = x.reshape( + shape_or[0] * shape_or[1], shape_or[3], shape_or[2] + ) + + elif not self.skip_transpose: + x = x.transpose(-1, 1) + + x_n = self.norm(x) + + if self.combine_batch_time: + x_n = x_n.reshape(shape_or) + elif not self.skip_transpose: + x_n = x_n.transpose(1, -1) + + return x_n + + +class BatchNorm2d(nn.Module): + """Applies 2d batch normalization to the input tensor. + + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + momentum : float + It is a value used for the running_mean and running_var computation. + affine : bool + When set to True, the affine parameters are learned. + track_running_stats : bool + When set to True, this module tracks the running mean and variance, + and when set to False, this module does not track such statistics. + + Example + ------- + >>> input = torch.randn(100, 10, 5, 20) + >>> norm = BatchNorm2d(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + torch.Size([100, 10, 5, 20]) + """ + + def __init__( + self, + input_shape=None, + input_size=None, + eps=1e-05, + momentum=0.1, + affine=True, + track_running_stats=True, + ): + super().__init__() + + if input_shape is None and input_size is None: + raise ValueError("Expected input_shape or input_size as input") + + if input_size is None: + input_size = input_shape[-1] + + self.norm = nn.BatchNorm2d( + input_size, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, channel1, channel2) + input to normalize. 4d tensors are expected. + + Returns + ------- + x_n : torch.Tensor + The normalized outputs. + """ + x = x.transpose(-1, 1) + x_n = self.norm(x) + x_n = x_n.transpose(1, -1) + + return x_n + + +class LayerNorm(nn.Module): + """Applies layer normalization to the input tensor. + + Arguments + --------- + input_size : int + The expected size of the dimension to be normalized. + input_shape : tuple + The expected shape of the input. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + elementwise_affine : bool + If True, this module has learnable per-element affine parameters + initialized to ones (for weights) and zeros (for biases). + + Example + ------- + >>> input = torch.randn(100, 101, 128) + >>> norm = LayerNorm(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + torch.Size([100, 101, 128]) + """ + + def __init__( + self, + input_size=None, + input_shape=None, + eps=1e-05, + elementwise_affine=True, + ): + super().__init__() + self.eps = eps + self.elementwise_affine = elementwise_affine + + if input_shape is not None: + input_size = input_shape[2:] + + self.norm = torch.nn.LayerNorm( + input_size, + eps=self.eps, + elementwise_affine=self.elementwise_affine, + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, channels) + input to normalize. 3d or 4d tensors are expected. + + Returns + ------- + The normalized outputs. + """ + return self.norm(x) + + +class InstanceNorm1d(nn.Module): + """Applies 1d instance normalization to the input tensor. + + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + momentum : float + It is a value used for the running_mean and running_var computation. + track_running_stats : bool + When set to True, this module tracks the running mean and variance, + and when set to False, this module does not track such statistics. + affine : bool + A boolean value that when set to True, this module has learnable + affine parameters, initialized the same way as done for + batch normalization. Default: False. + + Example + ------- + >>> input = torch.randn(100, 10, 20) + >>> norm = InstanceNorm1d(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + torch.Size([100, 10, 20]) + """ + + def __init__( + self, + input_shape=None, + input_size=None, + eps=1e-05, + momentum=0.1, + track_running_stats=True, + affine=False, + ): + super().__init__() + + if input_shape is None and input_size is None: + raise ValueError("Expected input_shape or input_size as input") + + if input_size is None: + input_size = input_shape[-1] + + self.norm = nn.InstanceNorm1d( + input_size, + eps=eps, + momentum=momentum, + track_running_stats=track_running_stats, + affine=affine, + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, channels) + input to normalize. 3d tensors are expected. + + Returns + ------- + x_n : torch.Tensor + The normalized outputs. + """ + x = x.transpose(-1, 1) + x_n = self.norm(x) + x_n = x_n.transpose(1, -1) + + return x_n + + +class InstanceNorm2d(nn.Module): + """Applies 2d instance normalization to the input tensor. + + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + momentum : float + It is a value used for the running_mean and running_var computation. + track_running_stats : bool + When set to True, this module tracks the running mean and variance, + and when set to False, this module does not track such statistics. + affine : bool + A boolean value that when set to True, this module has learnable + affine parameters, initialized the same way as done for + batch normalization. Default: False. + + Example + ------- + >>> input = torch.randn(100, 10, 20, 2) + >>> norm = InstanceNorm2d(input_shape=input.shape) + >>> output = norm(input) + >>> output.shape + torch.Size([100, 10, 20, 2]) + """ + + def __init__( + self, + input_shape=None, + input_size=None, + eps=1e-05, + momentum=0.1, + track_running_stats=True, + affine=False, + ): + super().__init__() + + if input_shape is None and input_size is None: + raise ValueError("Expected input_shape or input_size as input") + + if input_size is None: + input_size = input_shape[-1] + + self.norm = nn.InstanceNorm2d( + input_size, + eps=eps, + momentum=momentum, + track_running_stats=track_running_stats, + affine=affine, + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, channel1, channel2) + input to normalize. 4d tensors are expected. + + Returns + ------- + x_n : torch.Tensor + The normalized outputs. + """ + x = x.transpose(-1, 1) + x_n = self.norm(x) + x_n = x_n.transpose(1, -1) + + return x_n + + +class GroupNorm(nn.Module): + """Applies group normalization to the input tensor. + + Arguments + --------- + input_shape : tuple + The expected shape of the input. Alternatively, use ``input_size``. + input_size : int + The expected size of the input. Alternatively, use ``input_shape``. + num_groups : int + Number of groups to separate the channels into. + eps : float + This value is added to std deviation estimation to improve the numerical + stability. + affine : bool + A boolean value that when set to True, this module has learnable per-channel + affine parameters initialized to ones (for weights) and zeros (for biases). + + Example + ------- + >>> input = torch.randn(100, 101, 128) + >>> norm = GroupNorm(input_size=128, num_groups=128) + >>> output = norm(input) + >>> output.shape + torch.Size([100, 101, 128]) + """ + + def __init__( + self, + input_shape=None, + input_size=None, + num_groups=None, + eps=1e-05, + affine=True, + ): + super().__init__() + self.eps = eps + self.affine = affine + + if input_shape is None and input_size is None: + raise ValueError("Expected input_shape or input_size as input") + + if num_groups is None: + raise ValueError("Expected num_groups as input") + + if input_shape is not None: + input_size = input_shape[-1] + + self.norm = torch.nn.GroupNorm( + num_groups, + input_size, + eps=self.eps, + affine=self.affine, + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, channels) + input to normalize. 3d or 4d tensors are expected. + + Returns + ------- + x_n : torch.Tensor + The normalized outputs. + """ + x = x.transpose(-1, 1) + x_n = self.norm(x) + x_n = x_n.transpose(1, -1) + + return x_n + + +class ExponentialMovingAverage(nn.Module): + """ + Applies learnable exponential moving average, as required by learnable PCEN layer + + Arguments + --------- + input_size : int + The expected size of the input. + coeff_init: float + Initial smoothing coefficient value + per_channel: bool + Controls whether every smoothing coefficients are learned + independently for every input channel + trainable: bool + whether to learn the PCEN parameters or use fixed + skip_transpose : bool + If False, uses batch x time x channel convention of speechbrain. + If True, uses batch x channel x time convention. + + Example + ------- + >>> inp_tensor = torch.rand([10, 50, 40]) + >>> pcen = ExponentialMovingAverage(40) + >>> out_tensor = pcen(inp_tensor) + >>> out_tensor.shape + torch.Size([10, 50, 40]) + """ + + def __init__( + self, + input_size: int, + coeff_init: float = 0.04, + per_channel: bool = False, + trainable: bool = True, + skip_transpose: bool = False, + ): + super().__init__() + self._coeff_init = coeff_init + self._per_channel = per_channel + self.skip_transpose = skip_transpose + self.trainable = trainable + weights = ( + torch.ones( + input_size, + ) + if self._per_channel + else torch.ones( + 1, + ) + ) + self._weights = nn.Parameter( + weights * self._coeff_init, requires_grad=trainable + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, channels) + input to normalize. + """ + if not self.skip_transpose: + x = x.transpose(1, -1) + w = torch.clamp(self._weights, min=0.0, max=1.0) + initial_state = x[:, :, 0] + + def scan(init_state, x, w): + """Loops and accumulates.""" + x = x.permute(2, 0, 1) + acc = init_state + results = [] + for ix in range(x.shape[0]): + acc = (w * x[ix]) + ((1.0 - w) * acc) + results.append(acc.unsqueeze(0)) + results = torch.cat(results, dim=0) + results = results.permute(1, 2, 0) + return results + + output = scan(initial_state, x, w) + if not self.skip_transpose: + output = output.transpose(1, -1) + return output + + +class PCEN(nn.Module): + """ + This class implements a learnable Per-channel energy normalization (PCEN) layer, supporting both + original PCEN as specified in [1] as well as sPCEN as specified in [2] + + [1] Yuxuan Wang, Pascal Getreuer, Thad Hughes, Richard F. Lyon, Rif A. Saurous, "Trainable Frontend For + Robust and Far-Field Keyword Spotting", in Proc of ICASSP 2017 (https://arxiv.org/abs/1607.05666) + + [2] Neil Zeghidour, Olivier Teboul, F{\'e}lix de Chaumont Quitry & Marco Tagliasacchi, "LEAF: A LEARNABLE FRONTEND + FOR AUDIO CLASSIFICATION", in Proc of ICLR 2021 (https://arxiv.org/abs/2101.08596) + + The default argument values correspond with those used by [2]. + + Arguments + --------- + input_size : int + The expected size of the input. + alpha: float + specifies alpha coefficient for PCEN + smooth_coef: float + specified smooth coefficient for PCEN + delta: float + specifies delta coefficient for PCEN + root: float + specifies root coefficient for PCEN + floor: float + specifies floor coefficient for PCEN + trainable: bool + whether to learn the PCEN parameters or use fixed + per_channel_smooth_coef: bool + whether to learn independent smooth coefficients for every channel. + when True, essentially using sPCEN from [2] + skip_transpose : bool + If False, uses batch x time x channel convention of speechbrain. + If True, uses batch x channel x time convention. + + Example + ------- + >>> inp_tensor = torch.rand([10, 50, 40]) + >>> pcen = PCEN(40, alpha=0.96) # sPCEN + >>> out_tensor = pcen(inp_tensor) + >>> out_tensor.shape + torch.Size([10, 50, 40]) + """ + + def __init__( + self, + input_size, + alpha: float = 0.96, + smooth_coef: float = 0.04, + delta: float = 2.0, + root: float = 2.0, + floor: float = 1e-12, + trainable: bool = True, + per_channel_smooth_coef: bool = True, + skip_transpose: bool = False, + ): + super().__init__() + self._smooth_coef = smooth_coef + self._floor = floor + self._per_channel_smooth_coef = per_channel_smooth_coef + self.skip_transpose = skip_transpose + self.alpha = nn.Parameter( + torch.ones(input_size) * alpha, requires_grad=trainable + ) + self.delta = nn.Parameter( + torch.ones(input_size) * delta, requires_grad=trainable + ) + self.root = nn.Parameter( + torch.ones(input_size) * root, requires_grad=trainable + ) + + self.ema = ExponentialMovingAverage( + input_size, + coeff_init=self._smooth_coef, + per_channel=self._per_channel_smooth_coef, + skip_transpose=True, + trainable=trainable, + ) + + def forward(self, x): + """Returns the normalized input tensor. + + Arguments + --------- + x : torch.Tensor (batch, time, channels) + input to normalize. + + Returns + ------- + output : torch.Tensor + The normalized outputs. + """ + if not self.skip_transpose: + x = x.transpose(1, -1) + alpha = torch.min( + self.alpha, torch.tensor(1.0, dtype=x.dtype, device=x.device) + ) + root = torch.max( + self.root, torch.tensor(1.0, dtype=x.dtype, device=x.device) + ) + ema_smoother = self.ema(x) + one_over_root = 1.0 / root + output = ( + x / (self._floor + ema_smoother) ** alpha.view(1, -1, 1) + + self.delta.view(1, -1, 1) + ) ** one_over_root.view(1, -1, 1) - self.delta.view( + 1, -1, 1 + ) ** one_over_root.view( + 1, -1, 1 + ) + if not self.skip_transpose: + output = output.transpose(1, -1) + return output diff --git a/indextts/BigVGAN/utils.py b/indextts/BigVGAN/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e968fd49c045c01c5974c3a10e0818ac1391556b --- /dev/null +++ b/indextts/BigVGAN/utils.py @@ -0,0 +1,101 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import glob +import os + +import matplotlib +import matplotlib.pylab as plt +import torch +from scipy.io.wavfile import write +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") + +MAX_WAV_VALUE = 32768.0 + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def plot_spectrogram_clipped(spectrogram, clip_max=2.0): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow( + spectrogram, + aspect="auto", + origin="lower", + interpolation="none", + vmin=1e-6, + vmax=clip_max, + ) + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix, renamed_file=None): + # Fallback to original scanning logic first + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + + if len(cp_list) > 0: + last_checkpoint_path = sorted(cp_list)[-1] + print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'") + return last_checkpoint_path + + # If no pattern-based checkpoints are found, check for renamed file + if renamed_file: + renamed_path = os.path.join(cp_dir, renamed_file) + if os.path.isfile(renamed_path): + print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'") + return renamed_path + + return None + + +def save_audio(audio, path, sr): + # wav: torch with 1d shape + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype("int16") + write(path, sr, audio) diff --git a/indextts/__init__.py b/indextts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/cli.py b/indextts/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e2f5183a3cf4da0bd33cd3ad093636b5a89979 --- /dev/null +++ b/indextts/cli.py @@ -0,0 +1,62 @@ +import os +import sys +import warnings +# Suppress warnings from tensorflow and other libraries +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) +def main(): + import argparse + parser = argparse.ArgumentParser(description="IndexTTS Command Line") + parser.add_argument("text", type=str, help="Text to be synthesized") + parser.add_argument("-v", "--voice", type=str, required=True, help="Path to the audio prompt file (wav format)") + parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file") + parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'") + parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'") + parser.add_argument("--fp16", action="store_true", default=True, help="Use FP16 for inference if available") + parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists") + parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps)." ) + args = parser.parse_args() + if len(args.text.strip()) == 0: + print("ERROR: Text is empty.") + parser.print_help() + sys.exit(1) + if not os.path.exists(args.voice): + print(f"Audio prompt file {args.voice} does not exist.") + parser.print_help() + sys.exit(1) + if not os.path.exists(args.config): + print(f"Config file {args.config} does not exist.") + parser.print_help() + sys.exit(1) + + output_path = args.output_path + if os.path.exists(output_path): + if not args.force: + print(f"ERROR: Output file {output_path} already exists. Use --force to overwrite.") + parser.print_help() + sys.exit(1) + else: + os.remove(output_path) + + try: + import torch + except ImportError: + print("ERROR: PyTorch is not installed. Please install it first.") + sys.exit(1) + + if args.device is None: + if torch.cuda.is_available(): + args.device = "cuda:0" + elif torch.mps.is_available(): + args.device = "mps" + else: + args.device = "cpu" + args.fp16 = False # Disable FP16 on CPU + print("WARNING: Running on CPU may be slow.") + + from indextts.infer import IndexTTS + tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, is_fp16=args.fp16, device=args.device) + tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/indextts/gpt/__init__.py b/indextts/gpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/gpt/conformer/__init__.py b/indextts/gpt/conformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/gpt/conformer/attention.py b/indextts/gpt/conformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6ee5e313edf2e88a844ce004c0f819b0bd3260f6 --- /dev/null +++ b/indextts/gpt/conformer/attention.py @@ -0,0 +1,312 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-Head Attention layer definition.""" + +import math +from typing import Tuple + +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, n_head: int, n_feat: int, dropout_rate: float): + """Construct an MultiHeadedAttention object.""" + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention( + self, value: torch.Tensor, scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. pytorch training + if mask.size(2) > 0 : # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + Wenet. + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + """ + q, k, v = self.forward_qkv(query, key, value) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split( + cache, cache.size(-1) // 2, dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask), new_cache + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x, zero_triu: bool = False): + """Compute relative positinal encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, size). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + torch.Tensor: Output tensor. + """ + + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split( + cache, cache.size(-1) // 2, dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. + # matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/indextts/gpt/conformer/embedding.py b/indextts/gpt/conformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..97565e96db5d16d9c2ddb7adc850a68ee7e2ed77 --- /dev/null +++ b/indextts/gpt/conformer/embedding.py @@ -0,0 +1,163 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) + +"""Positonal Encoding Module.""" + +import math +from typing import Tuple, Union + +import torch +import torch.nn.functional as F + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.max_len = max_len + + pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2) * + -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + offset (int, torch.tensor): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ + + self.pe = self.pe.to(x.device) + pos_emb = self.position_encoding(offset, x.size(1), False) + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, offset: Union[int, torch.Tensor], size: int, + apply_dropout: bool = True) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + # How to subscript a Union type: + # https://github.com/pytorch/pytorch/issues/69434 + if isinstance(offset, int): + assert offset + size < self.max_len + pos_emb = self.pe[:, offset:offset + size] + elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar + assert offset + size < self.max_len + pos_emb = self.pe[:, offset:offset + size] + else: # for batched streaming decoding on GPU + assert torch.max(offset) + size < self.max_len + index = offset.unsqueeze(1) + \ + torch.arange(0, size).to(offset.device) # B X T + flag = index > 0 + # remove negative offset + index = index * flag + pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model + + if apply_dropout: + pos_emb = self.dropout(pos_emb) + return pos_emb + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + self.pe = self.pe.to(x.device) + x = x * self.xscale + pos_emb = self.position_encoding(offset, x.size(1), False) + return self.dropout(x), self.dropout(pos_emb) + + +class NoPositionalEncoding(torch.nn.Module): + """ No position encoding + """ + def __init__(self, d_model: int, dropout_rate: float): + super().__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def forward(self, + x: torch.Tensor, + offset: Union[int, torch.Tensor] = 0) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """ Just return zero vector for interface compatibility + """ + pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) + return self.dropout(x), pos_emb + + def position_encoding( + self, offset: Union[int, torch.Tensor], size: int) -> torch.Tensor: + return torch.zeros(1, size, self.d_model) diff --git a/indextts/gpt/conformer/subsampling.py b/indextts/gpt/conformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..363b4b18a4aa0e7fd755aa4b1ed15ede852d4b1e --- /dev/null +++ b/indextts/gpt/conformer/subsampling.py @@ -0,0 +1,348 @@ +# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) + + +"""Subsampling layer definition.""" + +from typing import Tuple, Union + +import torch + + +class BaseSubsampling(torch.nn.Module): + def __init__(self): + super().__init__() + self.right_context = 0 + self.subsampling_rate = 1 + + def position_encoding(self, offset: Union[int, torch.Tensor], + size: int) -> torch.Tensor: + return self.pos_enc.position_encoding(offset, size) + + +class LinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class Conv2dSubsampling3(BaseSubsampling): + """Convolutional 2D subsampling (to 1/3 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling3 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 5, 3), + torch.nn.ReLU() + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * ((idim - 2) // 3), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 3 + # 4 = (5 - 1) * 1 + self.right_context = 4 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 3. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 3. + torch.Tensor: positional encoding + + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, :-2:3] + + +class Conv2dSubsampling2(BaseSubsampling): + """Convolutional 2D subsampling (to 1/2 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling4 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * ((idim - 1) // 2), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 2 + # 2 = (3 - 1) * 1 + self.right_context = 2 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 2. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 2. + torch.Tensor: positional encoding + + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2] + + +class Conv2dSubsampling4(BaseSubsampling): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling4 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + torch.Tensor: positional encoding + + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2] + + +class Conv2dSubsampling6(BaseSubsampling): + """Convolutional 2D subsampling (to 1/6 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling6 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), + odim) + self.pos_enc = pos_enc_class + # 10 = (3 - 1) * 1 + (5 - 1) * 2 + self.subsampling_rate = 6 + self.right_context = 10 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3] + + +class Conv2dSubsampling8(BaseSubsampling): + """Convolutional 2D subsampling (to 1/8 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling8 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear( + odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) + self.pos_enc = pos_enc_class + self.subsampling_rate = 8 + # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 + self.right_context = 14 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: Union[int, torch.Tensor] = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2] diff --git a/indextts/gpt/conformer_encoder.py b/indextts/gpt/conformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a6967e619132eff9065c8c13029723a9b0f07c14 --- /dev/null +++ b/indextts/gpt/conformer_encoder.py @@ -0,0 +1,520 @@ + +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from indextts.gpt.conformer.attention import (MultiHeadedAttention, + RelPositionMultiHeadedAttention) +from indextts.gpt.conformer.embedding import (NoPositionalEncoding, + PositionalEncoding, + RelPositionalEncoding) +from indextts.gpt.conformer.subsampling import (Conv2dSubsampling2, + Conv2dSubsampling4, + Conv2dSubsampling6, + Conv2dSubsampling8, + LinearNoSubsampling) +from indextts.utils.common import make_pad_mask + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (torch.nn.Module): Activation function + """ + + def __init__(self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU()): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.activation = activation + self.dropout = torch.nn.Dropout(dropout_rate) + self.w_2 = torch.nn.Linear(hidden_units, idim) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model.""" + + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + bias: bool = True): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + super().__init__() + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. + # else: it's a symmetrical convolution + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + x = torch.cat((cache, x), dim=2) + assert (x.size(2) > self.lorder) + new_cache = x[:, :, -self.lorder:] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + return x.transpose(1, 2), new_cache + + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. + concat_after (bool): Whether to concat attention layer's input and + output. + True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + concat_after: bool = False, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module + self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = nn.LayerNorm(size, + eps=1e-5) # for the CNN module + self.norm_final = nn.LayerNorm( + size, eps=1e-5) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + else: + self.concat_linear = nn.Identity() + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + x_att, new_att_cache = self.self_attn( + x, x, x, mask, pos_emb, att_cache) + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache + + +class BaseEncoder(torch.nn.Module): + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + concat_after: bool = False, + ): + """ + Args: + input_size (int): input dim + output_size (int): dimension of attention + attention_heads (int): the number of heads of multi head attention + linear_units (int): the hidden units number of position-wise feed + forward + num_blocks (int): the number of decoder blocks + dropout_rate (float): dropout rate + attention_dropout_rate (float): dropout rate in attention + positional_dropout_rate (float): dropout rate after adding + positional encoding + input_layer (str): input layer type. + optional [linear, conv2d, conv2d6, conv2d8] + pos_enc_layer_type (str): Encoder positional encoding layer type. + opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] + normalize_before (bool): + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + concat_after (bool): whether to concat attention layer's input + and output. + True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + static_chunk_size (int): chunk size for static chunk training and + decoding + use_dynamic_chunk (bool): whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dyanmic chunk size(use_dynamic_chunk = True) + global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module + use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + """ + super().__init__() + self._output_size = output_size + + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "rel_pos": + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "no_pos": + pos_enc_class = NoPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + subsampling_class = LinearNoSubsampling + elif input_layer == "conv2d2": + subsampling_class = Conv2dSubsampling2 + elif input_layer == "conv2d": + subsampling_class = Conv2dSubsampling4 + elif input_layer == "conv2d6": + subsampling_class = Conv2dSubsampling6 + elif input_layer == "conv2d8": + subsampling_class = Conv2dSubsampling8 + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.embed = subsampling_class( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + xs, pos_emb, masks = self.embed(xs, masks) + chunk_masks = masks + mask_pad = masks # (B, 1, T/subsample_rate) + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + +class ConformerEncoder(BaseEncoder): + """Conformer encoder module.""" + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + concat_after: bool = False, + macaron_style: bool = False, + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. + """ + + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + concat_after) + + activation = torch.nn.SiLU() + + # self-attention module definition + if pos_enc_layer_type != "rel_pos": + encoder_selfattn_layer = MultiHeadedAttention + else: + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + dropout_rate, + ) + + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (output_size, + cnn_module_kernel, + activation,) + + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer( + *positionwise_layer_args) if macaron_style else None, + convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ) for _ in range(num_blocks) + ]) diff --git a/indextts/gpt/model.py b/indextts/gpt/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5f3b25c1c1ddda8a89867e2c0f3782f9c4ae2ea4 --- /dev/null +++ b/indextts/gpt/model.py @@ -0,0 +1,713 @@ +import functools + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import transformers +from transformers import GPT2Config, LogitsProcessorList +from indextts.gpt.transformers_gpt2 import GPT2PreTrainedModel, GPT2Model + +# from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions +from transformers.utils.model_parallel_utils import (assert_device_map, + get_device_map) + +from indextts.gpt.conformer_encoder import ConformerEncoder +from indextts.gpt.perceiver import PerceiverResampler +from indextts.utils.arch_util import AttentionBlock +from indextts.utils.typical_sampling import TypicalLogitsWarper + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +class ResBlock(nn.Module): + """ + Basic residual convolutional block that uses GroupNorm. + """ + + def __init__(self, chan): + super().__init__() + self.net = nn.Sequential( + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan), + nn.ReLU(), + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan) + ) + + def forward(self, x): + return F.relu(self.net(x) + x) + + +class GPT2InferenceModel(GPT2PreTrainedModel): + def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False): + super().__init__(config) + # Note: the argument named `text_pos_emb` here actually represents the mel position embedding + self.transformer = gpt + self.text_pos_embedding = text_pos_emb + self.embeddings = embeddings + self.final_norm = norm + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + + # Model parallel + self.model_parallel = False + self.device_map = None + self.cached_mel_emb = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count()))) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + if torch.backends.mps.is_available(): + torch.mps.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def store_mel_emb(self, mel_emb): + self.cached_mel_emb = mel_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_mel_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + # Create embedding + mel_len = self.cached_mel_emb.shape[1] + if input_ids.shape[1] != 1: + text_inputs = input_ids[:, mel_len:] + text_emb = self.embeddings(text_inputs) + text_emb = text_emb + self.text_pos_embedding(text_emb) + if self.cached_mel_emb.shape[0] != text_emb.shape[0]: + mel_emb = self.cached_mel_emb.repeat_interleave( + text_emb.shape[0] // self.cached_mel_emb.shape[0], 0 + ) + else: # this outcome only occurs once per loop in most cases + mel_emb = self.cached_mel_emb + emb = torch.cat([mel_emb, text_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.text_pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - mel_len, attention_mask.device + ) + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + if torch.backends.mps.is_available(): + self.to(self.transformer.first_device) + else: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) + for layer_past in past + ) + + +class ConditioningEncoder(nn.Module): + def __init__(self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4, + do_checkpointing=False, + mean=False): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + self.do_checkpointing = do_checkpointing + self.mean = mean + + def forward(self, x): + h = self.init(x) + h = self.attn(h) + if self.mean: + return h.mean(dim=2) + else: + return h + # return h[:, :, 0] + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=.02): + super().__init__() + self.emb = nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + + def forward(self, x): + sl = x.shape[1] + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + + +def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, activation_function): + """ + GPT-2 implemented by the HuggingFace library. + """ + from transformers import GPT2Config, GPT2Model + gpt_config = GPT2Config(vocab_size=256, # Unused. + n_positions=max_mel_seq_len + max_text_seq_len, + n_ctx=max_mel_seq_len + max_text_seq_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + activation_function=activation_function or "gelu_new", + gradient_checkpointing=checkpointing, + use_cache=not checkpointing) + gpt = GPT2Model(gpt_config) + # Override the built in positional embeddings + del gpt.wpe + gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) + # Built-in token embeddings are unused. + del gpt.wte + return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \ + None, None + + +class MelEncoder(nn.Module): + def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): + super().__init__() + self.channels = channels + self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1), + nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 16, channels // 2), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 8, channels), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), + ) + self.reduction = 4 + + def forward(self, x): + for e in self.encoder: + x = e(x) + return x.permute(0, 2, 1) + + +class UnifiedVoice(nn.Module): + def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, + mel_length_compression=1024, number_text_tokens=256, + start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193, + train_solo_embeddings=False, use_mel_codes_as_input=True, + checkpointing=True, types=1, activation_function=None, + condition_num_latent=32, condition_type="perceiver", condition_module=None): + """ + Args: + layers: Number of layers in transformer stack. + model_dim: Operating dimensions of the transformer + heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 + max_text_tokens: Maximum number of text tokens that will be encountered by model. + max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. + max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). + mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. + number_text_tokens: + start_text_token: + stop_text_token: + number_mel_codes: + start_mel_token: + stop_mel_token: + train_solo_embeddings: + use_mel_codes_as_input: + checkpointing: + condition_type: perceiver, gst or default encoder + """ + super().__init__() + self.number_text_tokens = number_text_tokens + self.start_text_token = start_text_token + self.stop_text_token = stop_text_token + self.number_mel_codes = number_mel_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token + self.layers = layers + self.heads = heads + self.max_mel_tokens = max_mel_tokens + self.max_text_tokens = max_text_tokens + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.mel_length_compression = mel_length_compression + self.condition_type = condition_type + self.cond_num = condition_num_latent + self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True) + if condition_type == "perceiver": + self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads) + self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=model_dim, num_latents=self.cond_num) + elif condition_type == "conformer_perceiver" or condition_type == "conformer_encoder": + self.conditioning_encoder = ConformerEncoder(input_size=100, + output_size=condition_module['output_size'], + linear_units=condition_module['linear_units'], + attention_heads=condition_module['attention_heads'], + num_blocks=condition_module['num_blocks'], + input_layer=condition_module['input_layer']) + if condition_type == "conformer_perceiver": + self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'], + ff_mult=condition_module['perceiver_mult'], + heads=condition_module['attention_heads'], + num_latents=self.cond_num) + else: + self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads, mean=True) + + self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim) + if use_mel_codes_as_input: + self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + else: + self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ + build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs, + self.max_text_tokens + 2, checkpointing, activation_function) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1) + self.mel_head = nn.Linear(model_dim, self.number_mel_codes) + + # Initialize the embeddings per the GPT-2 scheme + embeddings = [self.text_embedding] + if use_mel_codes_as_input: + embeddings.append(self.mel_embedding) + for module in embeddings: + module.weight.data.normal_(mean=0.0, std=.02) + + def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False): + seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + gpt_config = GPT2Config( + vocab_size=self.number_mel_codes, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + if use_deepspeed and half and torch.cuda.is_available(): + import deepspeed + self.ds_engine = deepspeed.init_inference(model=self.inference_model, + mp_size=1, + replace_with_kernel_inject=False, + dtype=torch.float16) + self.inference_model = self.ds_engine.module.eval() + elif use_deepspeed and torch.cuda.is_available(): + import deepspeed + self.ds_engine = deepspeed.init_inference(model=self.inference_model, + mp_size=1, + replace_with_kernel_inject=False, + dtype=torch.float32) + self.inference_model = self.ds_engine.module.eval() + else: + self.inference_model = self.inference_model.eval() + + # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.gpt.wte = self.mel_embedding + + def build_aligned_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, mel_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + for b in range(len(mel_lengths)): + # Due to the convolutional nature of how these tokens are generated, + # it would be best if the model predicts a token past the actual last token. + actual_end = mel_lengths[b] + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_mel_token + return mel_input_tokens + + def set_text_padding(self, text_input_tokens, text_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + for b in range(len(text_lengths)): + # Due to the convolutional nature of how these tokens are generated, + # it would be best if the model predicts a token past the actual last token. + actual_end = text_lengths[b] + if actual_end < text_input_tokens.shape[-1]: + text_input_tokens[b, actual_end:] = self.stop_text_token + return text_input_tokens + + def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False): + if second_inputs is not None: + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) + + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) + if get_attns: + return gpt_out.attentions + + offset = speech_conditioning_inputs.shape[1] + enc = gpt_out.last_hidden_state[:, offset:] + enc = self.final_norm(enc) + + if return_latent: + return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:] + + first_logits = enc[:, :first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1]:] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None): + if self.condition_type == "perceiver": + if speech_conditioning_input.ndim == 4: + speech_conditioning_input = speech_conditioning_input.squeeze(1) + speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input) # (b, d, s) + conds = self.perceiver_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 32, d) + elif self.condition_type == "conformer_perceiver": + speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2), + cond_mel_lengths) # (b, s, d), (b, 1, s) + if self.condition_type == "conformer_perceiver": + # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1) + conds_mask = self.cond_mask_pad(mask.squeeze(1)) + conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d) + elif self.condition_type == "gst": + if speech_conditioning_input.ndim == 4: + speech_conditioning_input = speech_conditioning_input.squeeze(1) + conds = self.gst_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 1, d) + else: + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + conds = conds.unsqueeze(1) + return conds + + def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths, + cond_mel_lengths=None, types=None, text_first=True, raw_mels=None, return_attentions=False, + return_latent=False, clip_inputs=False): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + speech_conditioning_input: MEL float tensor, (b,1024) + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + raw_mels: MEL float tensor (b,80,s) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. + """ + + speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths) + # Types are expressed by expanding the text embedding space. + if types is not None: + text_inputs = text_inputs * (1 + types).unsqueeze(-1) + + if clip_inputs: + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. + max_text_len = text_lengths.max() + text_inputs = text_inputs[:, :max_text_len] + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = mel_codes[:, :max_mel_len] + if raw_mels is not None: + raw_mels = raw_mels[:, :, :max_mel_len * 4] + + # Set padding areas within MEL (currently it is coded with the MEL code for ). + # mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc') + mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1 + mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths) + text_inputs = self.set_text_padding(text_inputs, text_lengths) + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token) + + conds = speech_conditioning_latent + text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) + if raw_mels is not None: + mel_inp = F.pad(raw_mels, (0, 8)) + else: + mel_inp = mel_codes + mel_emb = self.mel_embedding(mel_inp) + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + + if text_first: + # print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}") + text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) + if return_latent: + return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + else: + mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) + if return_latent: + return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + + if return_attentions: + return mel_logits + + loss_text = F.cross_entropy(text_logits, text_targets.long()) + loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) + return loss_text.mean(), loss_mel.mean(), mel_logits + + def prepare_gpt_inputs( + self, + conditional_latents: torch.Tensor, + text_inputs: torch.Tensor, + ): + + """ + Prepare the inputs for the GPT2InferenceModel to generate. + Args: + conds_latent: (b, 32, dim) audio conditioning embedding by `get_conditioning()` + text_inputs: (b, L) + Returns: + input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate() + inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward() + attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate() + """ + b, L = text_inputs.shape[:2] + device = text_inputs.device + single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1 + if not single_cond: + assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}" + batched_mel_emb = [] + attention_masks = [] + target_len = conditional_latents.shape[1] + L + 2 + for i in range(b): + valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token) + text_input = text_inputs[i][valid_mask] + text_input = F.pad(text_input, (1, 0), value=self.start_text_token) + text_input = F.pad(text_input, (0, 1), value=self.stop_text_token) + text_input_pos = torch.arange(0, text_input.size(-1), device=device) + text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos) + # concatenate [conditional latents][text embeddings] + conds_text_emb = [ + conditional_latents.squeeze(0) if single_cond else conditional_latents[i], + text_emb, + ] + # +1 for the start_mel_token + attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device) + # check this text input is padded + padding: int = L + 2 - text_input.size(-1) + # pad left of [cond][text] -> [pad][cond][text] + if padding > 0: + pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim] + conds_text_emb.insert(0, pad) + attention_mask[:padding] = 0 + mel_emb = torch.cat(conds_text_emb) #[s, dim] + assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}" + batched_mel_emb.append(mel_emb) + attention_masks.append(attention_mask) + # [b, s, dim] + batched_mel_emb = torch.stack(batched_mel_emb, dim=0) + # [b, s+1] + attention_mask = torch.stack(attention_masks, dim=0) + # [b, s+1] + fake_inputs = torch.ones( + ( + batched_mel_emb.shape[0], + batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token + ), + dtype=torch.long, + device=device, + ) + fake_inputs[:, -1] = self.start_mel_token + return fake_inputs, batched_mel_emb, attention_mask + def inference_speech(self, speech_conditioning_mel, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1, + max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): + """ + Args: + speech_conditioning_mel: (b, n_mels, frames) or (n_mels, frames) + text_inputs: (b, L) + cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,) + input_tokens: additional tokens for generation in shape (b, s) or (s,) + max_generate_length: limit the number of generated tokens + hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)` + """ + if speech_conditioning_mel.ndim == 2: + speech_conditioning_mel = speech_conditioning_mel.unsqueeze(0) + if cond_mel_lengths is None: + cond_mel_lengths = torch.tensor([speech_conditioning_mel.shape[-1]], device=speech_conditioning_mel.device) + conds_latent = self.get_conditioning(speech_conditioning_mel, cond_mel_lengths) + input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs) + self.inference_model.store_mel_emb(inputs_embeds) + if input_tokens is None: + inputs = input_ids + else: + if input_tokens.ndim == 1: + input_tokens = input_tokens.unsqueeze(0) + assert num_return_sequences % input_tokens.shape[0] == 0, \ + "The num_return_sequences must be divisible by the batch number of input_tokens" + assert num_return_sequences % text_inputs.shape[0] == 0, \ + "The num_return_sequences must be divisible by the batch number of text_inputs" + b = num_return_sequences // input_ids.shape[0] + if b > 1: + input_ids = input_ids.repeat(b, 1) + attention_mask = attention_mask.repeat(b, 1) + input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) + inputs = torch.cat([input_ids, input_tokens], dim=1) + attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1) + trunc_index = inputs.shape[1] + logits_processor = LogitsProcessorList() + if typical_sampling: + # employ custom typical sampling + if not (typical_mass > 0.0 and typical_mass < 1.0): + raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}") + min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1 + logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep)) + max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length + output = self.inference_model.generate(inputs, + bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, attention_mask=attention_mask, + max_length=max_length, logits_processor=logits_processor, + num_return_sequences=num_return_sequences, + **hf_generate_kwargs) + if isinstance(output, torch.Tensor): + return output[:, trunc_index:] + # GenerateOutput + output.sequences = output.sequences[:, trunc_index:] + return output diff --git a/indextts/gpt/model_v2.py b/indextts/gpt/model_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..3f39becc39c156fb291f0209849a38cbf4007b5f --- /dev/null +++ b/indextts/gpt/model_v2.py @@ -0,0 +1,747 @@ +import functools + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import transformers +from transformers import GPT2Config, LogitsProcessorList +from indextts.gpt.transformers_gpt2 import GPT2PreTrainedModel, GPT2Model + +# from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions +from transformers.utils.model_parallel_utils import (assert_device_map, + get_device_map) + +from indextts.gpt.conformer_encoder import ConformerEncoder +from indextts.gpt.perceiver import PerceiverResampler +from indextts.utils.arch_util import AttentionBlock +from indextts.utils.typical_sampling import TypicalLogitsWarper + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +class ResBlock(nn.Module): + """ + Basic residual convolutional block that uses GroupNorm. + """ + + def __init__(self, chan): + super().__init__() + self.net = nn.Sequential( + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan), + nn.ReLU(), + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan) + ) + + def forward(self, x): + return F.relu(self.net(x) + x) + + +class GPT2InferenceModel(GPT2PreTrainedModel): + def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False): + super().__init__(config) + # Note: the argument named `text_pos_emb` here actually represents the mel position embedding + self.transformer = gpt + self.text_pos_embedding = text_pos_emb + self.embeddings = embeddings + self.final_norm = norm + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + + # Model parallel + self.model_parallel = False + self.device_map = None + self.cached_mel_emb = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count()))) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + if torch.backends.mps.is_available(): + torch.mps.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def store_mel_emb(self, mel_emb): + self.cached_mel_emb = mel_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_mel_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + # Create embedding + mel_len = self.cached_mel_emb.shape[1] + if input_ids.shape[1] != 1: + text_inputs = input_ids[:, mel_len:] + text_emb = self.embeddings(text_inputs) + text_emb = text_emb + self.text_pos_embedding(text_emb) + if self.cached_mel_emb.shape[0] != text_emb.shape[0]: + mel_emb = self.cached_mel_emb.repeat_interleave( + text_emb.shape[0] // self.cached_mel_emb.shape[0], 0 + ) + else: # this outcome only occurs once per loop in most cases + mel_emb = self.cached_mel_emb + emb = torch.cat([mel_emb, text_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.text_pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - mel_len, attention_mask.device + ) + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + if torch.backends.mps.is_available(): + self.to(self.transformer.first_device) + else: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) + for layer_past in past + ) + + +class ConditioningEncoder(nn.Module): + def __init__(self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4, + do_checkpointing=False, + mean=False): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + self.do_checkpointing = do_checkpointing + self.mean = mean + + def forward(self, x): + h = self.init(x) + h = self.attn(h) + if self.mean: + return h.mean(dim=2) + else: + return h + # return h[:, :, 0] + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=.02): + super().__init__() + self.emb = nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + + def forward(self, x): + sl = x.shape[1] + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + + +def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): + """ + GPT-2 implemented by the HuggingFace library. + """ + from transformers import GPT2Config, GPT2Model + gpt_config = GPT2Config(vocab_size=256, # Unused. + n_positions=max_mel_seq_len + max_text_seq_len, + n_ctx=max_mel_seq_len + max_text_seq_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing) + gpt = GPT2Model(gpt_config) + # Override the built in positional embeddings + del gpt.wpe + gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) + # Built-in token embeddings are unused. + del gpt.wte + return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \ + None, None + + +class MelEncoder(nn.Module): + def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): + super().__init__() + self.channels = channels + self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1), + nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 16, channels // 2), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 8, channels), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), + ) + self.reduction = 4 + + def forward(self, x): + for e in self.encoder: + x = e(x) + return x.permute(0, 2, 1) + + +class UnifiedVoice(nn.Module): + def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1, + mel_length_compression=1024, number_text_tokens=256, + start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193, + train_solo_embeddings=False, use_mel_codes_as_input=True, + checkpointing=True, types=1, + condition_num_latent=32, condition_type="perceiver", condition_module=None, emo_condition_module=None): + """ + Args: + layers: Number of layers in transformer stack. + model_dim: Operating dimensions of the transformer + heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 + max_text_tokens: Maximum number of text tokens that will be encountered by model. + max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. + max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). + mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. + number_text_tokens: + start_text_token: + stop_text_token: + number_mel_codes: + start_mel_token: + stop_mel_token: + train_solo_embeddings: + use_mel_codes_as_input: + checkpointing: + condition_type: perceiver, gst or default encoder + """ + super().__init__() + self.number_text_tokens = number_text_tokens + self.start_text_token = start_text_token + self.stop_text_token = stop_text_token + self.number_mel_codes = number_mel_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token + self.layers = layers + self.heads = heads + self.max_mel_tokens = max_mel_tokens + self.max_text_tokens = max_text_tokens + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.mel_length_compression = mel_length_compression + self.condition_type = condition_type + self.cond_num = condition_num_latent + self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True) + self.emo_cond_mask_pad = nn.ConstantPad1d((1, 0), True) + if condition_type == "perceiver": + self.conditioning_encoder = ConditioningEncoder(1024, model_dim, num_attn_heads=heads) + self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=model_dim, num_latents=self.cond_num) + elif condition_type == "conformer_perceiver" or condition_type == "conformer_encoder": + self.conditioning_encoder = ConformerEncoder(input_size=1024, + output_size=condition_module['output_size'], + linear_units=condition_module['linear_units'], + attention_heads=condition_module['attention_heads'], + num_blocks=condition_module['num_blocks'], + input_layer=condition_module['input_layer']) + if condition_type == "conformer_perceiver": + self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'], + ff_mult=condition_module['perceiver_mult'], + heads=condition_module['attention_heads'], + num_latents=self.cond_num) + else: + self.conditioning_encoder = ConditioningEncoder(1024, model_dim, num_attn_heads=heads, mean=True) + + self.emo_conditioning_encoder = ConformerEncoder(input_size=1024, + output_size=emo_condition_module['output_size'], + linear_units=emo_condition_module['linear_units'], + attention_heads=emo_condition_module['attention_heads'], + num_blocks=emo_condition_module['num_blocks'], + input_layer=emo_condition_module['input_layer']) + self.emo_perceiver_encoder = PerceiverResampler(1024, dim_context=emo_condition_module['output_size'], + ff_mult=emo_condition_module['perceiver_mult'], + heads=emo_condition_module['attention_heads'], + num_latents=1) + + + + self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim) + self.emo_layer = nn.Linear(model_dim, model_dim) + self.emovec_layer = nn.Linear(1024, model_dim) + + if use_mel_codes_as_input: + self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + else: + self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \ + build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs, + self.max_text_tokens + 2, checkpointing) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1) + self.mel_head = nn.Linear(model_dim, self.number_mel_codes) + + self.speed_emb = nn.Embedding(2, model_dim) + self.speed_emb.weight.data.normal_(mean=0.0, std=0.0) + + # Initialize the embeddings per the GPT-2 scheme + embeddings = [self.text_embedding] + if use_mel_codes_as_input: + embeddings.append(self.mel_embedding) + for module in embeddings: + module.weight.data.normal_(mean=0.0, std=.02) + + def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False): + seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + gpt_config = GPT2Config( + vocab_size=self.number_mel_codes, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + if use_deepspeed and half and torch.cuda.is_available(): + import deepspeed + self.ds_engine = deepspeed.init_inference(model=self.inference_model, + mp_size=1, + replace_with_kernel_inject=True, + dtype=torch.float16) + self.inference_model = self.ds_engine.module.eval() + elif use_deepspeed and torch.cuda.is_available(): + import deepspeed + self.ds_engine = deepspeed.init_inference(model=self.inference_model, + mp_size=1, + replace_with_kernel_inject=True, + dtype=torch.float32) + self.inference_model = self.ds_engine.module.eval() + else: + self.inference_model = self.inference_model.eval() + + # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.gpt.wte = self.mel_embedding + + def build_aligned_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, mel_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + for b in range(len(mel_lengths)): + # Due to the convolutional nature of how these tokens are generated, + # it would be best if the model predicts a token past the actual last token. + actual_end = mel_lengths[b] + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_mel_token + return mel_input_tokens + + def set_text_padding(self, text_input_tokens, text_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + for b in range(len(text_lengths)): + # Due to the convolutional nature of how these tokens are generated, + # it would be best if the model predicts a token past the actual last token. + actual_end = text_lengths[b] + if actual_end < text_input_tokens.shape[-1]: + text_input_tokens[b, actual_end:] = self.stop_text_token + return text_input_tokens + + def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False): + if second_inputs is not None: + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) + + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) + if get_attns: + return gpt_out.attentions + + offset = speech_conditioning_inputs.shape[1] + enc = gpt_out.last_hidden_state[:, offset:] + enc = self.final_norm(enc) + + if return_latent: + return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:] + + first_logits = enc[:, :first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1]:] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None): + if self.condition_type == "perceiver": + if speech_conditioning_input.ndim == 4: + speech_conditioning_input = speech_conditioning_input.squeeze(1) + speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input) # (b, d, s) + conds = self.perceiver_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 32, d) + elif self.condition_type == "conformer_perceiver": + speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2), + cond_mel_lengths) # (b, s, d), (b, 1, s) + if self.condition_type == "conformer_perceiver": + # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1) + conds_mask = self.cond_mask_pad(mask.squeeze(1)) + conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d) + elif self.condition_type == "gst": + if speech_conditioning_input.ndim == 4: + speech_conditioning_input = speech_conditioning_input.squeeze(1) + conds = self.gst_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 1, d) + else: + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + conds = conds.unsqueeze(1) + return conds + + + def get_emo_conditioning(self, speech_conditioning_input, cond_mel_lengths=None): + speech_conditioning_input, mask = self.emo_conditioning_encoder(speech_conditioning_input.transpose(1, 2), + cond_mel_lengths) # (b, s, d), (b, 1, s) + conds_mask = self.emo_cond_mask_pad(mask.squeeze(1)) + conds = self.emo_perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 1, d) + return conds.squeeze(1) + + + def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, mel_codes_lengths, emo_speech_conditioning_latent, + cond_mel_lengths=None, emo_cond_mel_lengths=None, emo_vec=None, use_speed=None, do_spk_cond=False): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + + speech_conditioning_input: MEL float tensor, (b,1024) + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + """ + + if do_spk_cond: + speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent.transpose(1,2), cond_mel_lengths) + else: + speech_conditioning_latent = speech_conditioning_latent + + if emo_vec is None: + emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_mel_lengths) + emo_vec_syn = self.emovec_layer(emo_vec_syn_ori) + emo_vec = self.emo_layer(emo_vec_syn) + + text_inputs = self.set_text_padding(text_inputs, text_lengths) + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + + mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths) + mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token) + + duration_emb = self.speed_emb(torch.zeros_like(use_speed)) + duration_emb_half = self.speed_emb(torch.ones_like(use_speed)) + conds = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1) + text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) + + mel_emb = self.mel_embedding(mel_codes) + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + + text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=False, return_latent=True) + return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + + def prepare_gpt_inputs( + self, + conditional_latents: torch.Tensor, + text_inputs: torch.Tensor, + ): + + """ + Prepare the inputs for the GPT2InferenceModel to generate. + Args: + conds_latent: (b, 32, dim) audio conditioning embedding by `get_conditioning()` + text_inputs: (b, L) + Returns: + input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate() + inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward() + attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate() + """ + b, L = text_inputs.shape[:2] + device = text_inputs.device + single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1 + if not single_cond: + assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}" + batched_mel_emb = [] + attention_masks = [] + target_len = conditional_latents.shape[1] + L + 2 + for i in range(b): + valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token) + text_input = text_inputs[i][valid_mask] + text_input = F.pad(text_input, (1, 0), value=self.start_text_token) + text_input = F.pad(text_input, (0, 1), value=self.stop_text_token) + text_input_pos = torch.arange(0, text_input.size(-1), device=device) + text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos) + # concatenate [conditional latents][text embeddings] + conds_text_emb = [ + conditional_latents.squeeze(0) if single_cond else conditional_latents[i], + text_emb, + ] + # +1 for the start_mel_token + attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device) + # check this text input is padded + padding: int = L + 2 - text_input.size(-1) + # pad left of [cond][text] -> [pad][cond][text] + if padding > 0: + pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim] + conds_text_emb.insert(0, pad) + attention_mask[:padding] = 0 + mel_emb = torch.cat(conds_text_emb) #[s, dim] + assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}" + batched_mel_emb.append(mel_emb) + attention_masks.append(attention_mask) + # [b, s, dim] + batched_mel_emb = torch.stack(batched_mel_emb, dim=0) + # [b, s+1] + attention_mask = torch.stack(attention_masks, dim=0) + # [b, s+1] + fake_inputs = torch.ones( + ( + batched_mel_emb.shape[0], + batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token + ), + dtype=torch.long, + device=device, + ) + fake_inputs[:, -1] = self.start_mel_token + return fake_inputs, batched_mel_emb, attention_mask + + def inference_speech(self, speech_condition, text_inputs, emo_speech_condition=None, cond_lengths=None, emo_cond_lengths=None, emo_vec=None, use_speed=False, input_tokens=None, num_return_sequences=1, + max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs): + """ + Args: + speech_condition: (b, d, frames) or (d, frames) + text_inputs: (b, L) + cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,) + input_tokens: additional tokens for generation in shape (b, s) or (s,) + max_generate_length: limit the number of generated tokens + hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)` + """ + + if speech_condition.ndim == 2: + speech_condition = speech_condition.unsqueeze(0) + if emo_speech_condition is None: + emo_speech_condition = speech_condition + if cond_lengths is None: + cond_lengths = torch.tensor([speech_condition.shape[-1]], device=speech_condition.device) + if emo_cond_lengths is None: + emo_cond_lengths = torch.tensor([emo_speech_condition.shape[-1]], device=speech_condition.device) + + speech_conditioning_latent = self.get_conditioning(speech_condition.transpose(1,2), cond_lengths) + if emo_vec is None: + print('compute emo vec') + emo_vec = self.get_emo_conditioning(emo_speech_condition.transpose(1,2), emo_cond_lengths) + emo_vec = self.emovec_layer(emo_vec) + emo_vec = self.emo_layer(emo_vec) + else: + print('Use the specified emotion vector') + + tmp = torch.zeros(text_inputs.size(0)).to(text_inputs.device) + duration_emb = self.speed_emb(torch.zeros_like(tmp).long()) + duration_emb_half = self.speed_emb(torch.ones_like(tmp).long()) + conds_latent = torch.cat((speech_conditioning_latent + emo_vec.unsqueeze(1), duration_emb_half.unsqueeze(1), duration_emb.unsqueeze(1)), 1) + input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs) + self.inference_model.store_mel_emb(inputs_embeds) + if input_tokens is None: + inputs = input_ids + else: + if input_tokens.ndim == 1: + input_tokens = input_tokens.unsqueeze(0) + assert num_return_sequences % input_tokens.shape[0] == 0, \ + "The num_return_sequences must be divisible by the batch number of input_tokens" + assert num_return_sequences % text_inputs.shape[0] == 0, \ + "The num_return_sequences must be divisible by the batch number of text_inputs" + b = num_return_sequences // input_ids.shape[0] + if b > 1: + input_ids = input_ids.repeat(b, 1) + attention_mask = attention_mask.repeat(b, 1) + input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) + inputs = torch.cat([input_ids, input_tokens], dim=1) + attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1) + trunc_index = inputs.shape[1] + logits_processor = LogitsProcessorList() + if typical_sampling: + # employ custom typical sampling + if not (typical_mass > 0.0 and typical_mass < 1.0): + raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}") + min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1 + logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep)) + max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length + output = self.inference_model.generate(inputs, + bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, attention_mask=attention_mask, + max_length=max_length, logits_processor=logits_processor, + num_return_sequences=num_return_sequences, + **hf_generate_kwargs) + if isinstance(output, torch.Tensor): + return output[:, trunc_index:], speech_conditioning_latent + # GenerateOutput + output.sequences = output.sequences[:, trunc_index:] + return output, speech_conditioning_latent + + def get_emovec(self, emo_speech_conditioning_latent, emo_cond_lengths): + emo_vec_syn_ori = self.get_emo_conditioning(emo_speech_conditioning_latent.transpose(1,2), emo_cond_lengths) + emo_vec_syn = self.emovec_layer(emo_vec_syn_ori) + emo_vec = self.emo_layer(emo_vec_syn) + return emo_vec + + def merge_emovec(self, speech_conditioning_latent, emo_speech_conditioning_latent, cond_lengths, emo_cond_lengths, alpha = 1.0): + emo_vec = self.get_emovec(emo_speech_conditioning_latent, emo_cond_lengths) + base_vec = self.get_emovec(speech_conditioning_latent, cond_lengths) + + out = base_vec + alpha * (emo_vec - base_vec) + return out diff --git a/indextts/gpt/perceiver.py b/indextts/gpt/perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e9c30417901e9eb73ee41731cfade3910cc821 --- /dev/null +++ b/indextts/gpt/perceiver.py @@ -0,0 +1,317 @@ +# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532 + +from collections import namedtuple +from functools import wraps + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from packaging import version +from torch import einsum, nn + + +def exists(val): + return val is not None + + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +print_once = once(print) + + +# main class +class Attend(nn.Module): + def __init__(self, dropout=0.0, causal=False, use_flash=False): + super().__init__() + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.causal = causal + self.register_buffer("mask", None, persistent=False) + + self.use_flash = use_flash + assert not ( + use_flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "in order to use flash attention, you must be using pytorch 2.0 or above" + + # determine efficient attention configs for cuda and cpu + self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]) + self.cpu_config = self.config(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once("A100 GPU detected, using flash attention if input tensor is on cuda") + self.cuda_config = self.config(True, False, False) + else: + print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda") + self.cuda_config = self.config(False, True, True) + + def get_mask(self, n, device): + if exists(self.mask) and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def flash_attn(self, q, k, v, mask=None): + _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda + + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if k.ndim == 3: + k = rearrange(k, "b ... -> b 1 ...").expand_as(q) + + if v.ndim == 3: + v = rearrange(v, "b ... -> b 1 ...").expand_as(q) + + # Check if mask exists and expand to compatible shape + # The mask is B L, so it would have to be expanded to B H N L + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + mask = mask.expand(-1, heads, q_len, -1) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal + ) + + return out + + def forward(self, q, k, v, mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device = q.shape[-2], q.device + + scale = q.shape[-1] ** -0.5 + + if self.use_flash: + return self.flash_attn(q, k, v, mask=mask) + + kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" + + # similarity + + sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale + + # key padding mask + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # causal mask + + if self.causal: + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) + + return out + + +def Sequential(*mods): + return nn.Sequential(*filter(exists, mods)) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +class RMSNorm(nn.Module): + def __init__(self, dim, scale=True, dim_cond=None): + super().__init__() + self.cond = exists(dim_cond) + self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None + + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) if scale else None + + def forward(self, x, cond=None): + gamma = default(self.gamma, 1) + out = F.normalize(x, dim=-1) * self.scale * gamma + + if not self.cond: + return out + + assert exists(cond) + gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1) + gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta)) + return out * gamma + beta + + +class CausalConv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + (kernel_size,) = self.kernel_size + (dilation,) = self.dilation + (stride,) = self.stride + + assert stride == 1 + self.causal_padding = dilation * (kernel_size - 1) + + def forward(self, x): + causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0) + return super().forward(causal_padded_x) + + +class GEGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.gelu(gate) * x + + +def FeedForward(dim, mult=4, causal_conv=False): + dim_inner = int(dim * mult * 2 / 3) + + conv = None + if causal_conv: + conv = nn.Sequential( + Rearrange("b n d -> b d n"), + CausalConv1d(dim_inner, dim_inner, 3), + Rearrange("b d n -> b n d"), + ) + + return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)) + + +class PerceiverResampler(nn.Module): + def __init__( + self, + dim, + depth=2, + dim_context=None, + num_latents=32, + dim_head=64, + heads=8, + ff_mult=4, + use_flash_attn=False, + ): + super().__init__() + dim_context = default(dim_context, dim) + + self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() + + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + nn.init.normal_(self.latents, std=0.02) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + use_flash=use_flash_attn, + cross_attn_include_queries=True, + ), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = RMSNorm(dim) + + def forward(self, x, mask=None): + batch = x.shape[0] + + x = self.proj_context(x) + + latents = repeat(self.latents, "n d -> b n d", b=batch) + + for attn, ff in self.layers: + latents = attn(latents, x, mask=mask) + latents + latents = ff(latents) + latents + + return self.norm(latents) + + +class Attention(nn.Module): + def __init__( + self, + dim, + *, + dim_context=None, + causal=False, + dim_head=64, + heads=8, + dropout=0.0, + use_flash=False, + cross_attn_include_queries=False, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + self.cross_attn_include_queries = cross_attn_include_queries + + dim_inner = dim_head * heads + dim_context = default(dim_context, dim) + + self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash) + self.to_q = nn.Linear(dim, dim_inner, bias=False) + self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) + self.to_out = nn.Linear(dim_inner, dim, bias=False) + + def forward(self, x, context=None, mask=None): + h, has_context = self.heads, exists(context) + + context = default(context, x) + + if has_context and self.cross_attn_include_queries: + context = torch.cat((x, context), dim=-2) + + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = self.attend(q, k, v, mask=mask) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) diff --git a/indextts/gpt/transformers_beam_search.py b/indextts/gpt/transformers_beam_search.py new file mode 100755 index 0000000000000000000000000000000000000000..d82e081de84a7240f17ff1a4bcd0d218ba2147ce --- /dev/null +++ b/indextts/gpt/transformers_beam_search.py @@ -0,0 +1,1013 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections import UserDict +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from transformers.utils import add_start_docstrings +from transformers.generation.beam_constraints import Constraint, ConstraintListState + + +PROCESS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`): + Current scores of the top `2 * num_beams` non-finished beam hypotheses. + next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`): + `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses. + next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`): + Beam indices indicating to which beam hypothesis the `next_tokens` correspond. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + beam_indices (`torch.LongTensor`, *optional*): + Beam indices indicating to which beam hypothesis each token correspond. + group_index (`int`, *optional*): + The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`]. + + Return: + `UserDict`: A dictionary composed of the fields as defined above: + + - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all + non-finished beams. + - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added + to the non-finished beam_hypotheses. + - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices + indicating to which beam the next tokens shall be added. + +""" + +FINALIZE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`): + The final scores of all non-finished beams. + final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`): + The last tokens to be added to the non-finished beam_hypotheses. + final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`): + The beam indices indicating to which beam the `final_beam_tokens` shall be added. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + + Return: + `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. + The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early + due to the `eos_token_id`. + +""" + + +class BeamScorer(ABC): + """ + Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and + [`~PreTrainedModel.beam_sample`]. + """ + + @abstractmethod + @add_start_docstrings(PROCESS_INPUTS_DOCSTRING) + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + **kwargs, + ) -> Tuple[torch.Tensor]: + raise NotImplementedError("This is an abstract method.") + + @abstractmethod + @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING) + def finalize( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + max_length: int, + **kwargs, + ) -> torch.LongTensor: + raise NotImplementedError("This is an abstract method.") + + +class BeamSearchScorer(BeamScorer): + r""" + [`BeamScorer`] implementing standard beam search decoding. + + Adapted in part from [Facebook's XLM beam search + code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529). + + Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS + implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua) + + Args: + batch_size (`int`): + Batch Size of `input_ids` for which standard beam search decoding is run in parallel. + num_beams (`int`): + Number of beams for beam search. + device (`torch.device`): + Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be + allocated. + length_penalty (`float`, *optional*, defaults to 1.0): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to + the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log + likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while + `length_penalty` < 0.0 encourages shorter sequences. + do_early_stopping (`bool` or `str`, *optional*, defaults to `False`): + Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: + `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very unlikely to find better candidates; + `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical + beam search algorithm). + num_beam_hyps_to_keep (`int`, *optional*, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + [`~transformers.BeamSearchScorer.finalize`]. + num_beam_groups (`int`, *optional*, defaults to 1): + Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + max_length (`int`, *optional*): + The maximum length of the sequence to be generated. + """ + + def __init__( + self, + batch_size: int, + num_beams: int, + device: torch.device, + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[Union[bool, str]] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + num_beam_groups: Optional[int] = 1, + max_length: Optional[int] = None, + ): + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + self.num_beam_groups = num_beam_groups + self.group_size = self.num_beams // self.num_beam_groups + + self._is_init = False + # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch. + # If group_beam_search is not used, the list consists of `batch_size` beam_hyps. + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.group_size, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + max_length=max_length, + ) + for _ in range(batch_size * self.num_beam_groups) + ] + # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group + # in the i-th mini-batch is complete. + self._done = torch.tensor( + [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device + ) + + if not isinstance(num_beams, int) or num_beams <= 1: + raise ValueError( + f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1," + " one should make use of `greedy_search` instead." + ) + + if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): + raise ValueError( + "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be" + f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." + ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + pad_token_id: Optional[Union[int, torch.Tensor]] = None, + eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None, + beam_indices: Optional[torch.LongTensor] = None, + group_index: Optional[int] = 0, + decoder_prompt_len: Optional[int] = 0, + ) -> Dict[str, torch.Tensor]: + # add up to the length which the next_scores is calculated on (including decoder prompt) + cur_len = input_ids.shape[-1] + 1 + batch_size = len(self._beam_hyps) // self.num_beam_groups + + if not (batch_size == (input_ids.shape[0] // self.group_size)): + if self.num_beam_groups > 1: + raise ValueError( + f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam " + f"size of {self.group_size} is expected by the beam scorer." + ) + else: + raise ValueError( + f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of " + f"{self.group_size} is expected by the beam scorer." + ) + + device = input_ids.device + next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + + if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + + for batch_idx in range(batch_size): + batch_group_idx = batch_idx * self.num_beam_groups + group_index + if self._done[batch_group_idx]: + if self.num_beams < len(self._beam_hyps[batch_group_idx]): + raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated") + if eos_token_id is None or pad_token_id is None: + raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined") + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence + beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.group_size + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() in eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size + if is_beam_token_worse_than_top_num_beams: + continue + if beam_indices is not None: + beam_index = beam_indices[batch_beam_idx] + beam_index = beam_index + (batch_beam_idx,) + else: + beam_index = None + + self._beam_hyps[batch_group_idx].add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + beam_indices=beam_index, + generated_len=cur_len - decoder_prompt_len, + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.group_size: + break + + if beam_idx < self.group_size: + raise ValueError( + f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:" + f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done( + next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len + ) + + return UserDict( + { + "next_beam_scores": next_beam_scores.view(-1), + "next_beam_tokens": next_beam_tokens.view(-1), + "next_beam_indices": next_beam_indices.view(-1), + } + ) + + def finalize( + self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + final_beam_tokens: torch.LongTensor, + final_beam_indices: torch.LongTensor, + max_length: int, + pad_token_id: Optional[Union[int, torch.Tensor]] = None, + eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None, + beam_indices: Optional[torch.LongTensor] = None, + decoder_prompt_len: Optional[int] = 0, + ) -> Tuple[torch.LongTensor]: + batch_size = len(self._beam_hyps) // self.num_beam_groups + + if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_group_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_group_idx]: + continue + + # all open beam hypotheses are added to the beam hypothesis + # beam hypothesis class automatically keeps the best beams + for index_per_group in range(self.group_size): + batch_beam_idx = batch_group_idx * self.group_size + index_per_group + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None + generated_len = final_tokens.shape[-1] - decoder_prompt_len + beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len) + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + best_indices = [] + best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32) + + # retrieve best hypotheses + for i in range(batch_size): + beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups] + candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams] + sorted_hyps = sorted(candidate_beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + best_hyp_tuple = sorted_hyps.pop() + best_score = best_hyp_tuple[0] + best_hyp = best_hyp_tuple[1] + best_index = best_hyp_tuple[2] + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + + # append hyp to lists + best.append(best_hyp) + + # append indices to list + best_indices.append(best_index) + + best_scores[i * self.num_beam_hyps_to_keep + j] = best_score + + # prepare for adding eos + sent_lengths_max = sent_lengths.max().item() + 1 + sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max + decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + + if len(best_indices) > 0 and best_indices[0] is not None: + indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + else: + indices = None + + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + if pad_token_id is None: + raise ValueError("`pad_token_id` has to be defined") + decoded.fill_(pad_token_id) + + if indices is not None: + indices.fill_(-1) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, (hypo, best_idx) in enumerate(zip(best, best_indices)): + decoded[i, : sent_lengths[i]] = hypo + + if indices is not None: + indices[i, : len(best_idx)] = torch.tensor(best_idx) + + if sent_lengths[i] < sent_max_len: + # inserting only the first eos_token_id + decoded[i, sent_lengths[i]] = eos_token_id[0] + + return UserDict( + { + "sequences": decoded, + "sequence_scores": best_scores, + "beam_indices": indices, + } + ) + + +class ConstrainedBeamSearchScorer(BeamScorer): + r""" + [`BeamScorer`] implementing constrained beam search decoding. + + + Args: + batch_size (`int`): + Batch Size of `input_ids` for which standard beam search decoding is run in parallel. + num_beams (`int`): + Number of beams for beam search. + constraints (`List[Constraint]`): + A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation + output. For more information, the documentation of [`Constraint`] should be read. + device (`torch.device`): + Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be + allocated. + length_penalty (`float`, *optional*, defaults to 1.0): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to + the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log + likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while + `length_penalty` < 0.0 encourages shorter sequences. + do_early_stopping (`bool` or `str`, *optional*, defaults to `False`): + Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: + `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very unlikely to find better candidates; + `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical + beam search algorithm). + num_beam_hyps_to_keep (`int`, *optional*, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + [`~transformers.BeamSearchScorer.finalize`]. + num_beam_groups (`int`, *optional*, defaults to 1): + Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + max_length (`int`, *optional*): + The maximum length of the sequence to be generated. + """ + + def __init__( + self, + batch_size: int, + num_beams: int, + constraints: List[Constraint], + device: torch.device, + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[Union[bool, str]] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + num_beam_groups: Optional[int] = 1, + max_length: Optional[int] = None, + ): + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + self.num_beam_groups = num_beam_groups + self.group_size = self.num_beams // self.num_beam_groups + self.constraints = constraints + + self._is_init = False + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.num_beams, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + max_length=max_length, + ) + for _ in range(batch_size) + ] + self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) + + if not isinstance(num_beams, int) or num_beams <= 1: + raise ValueError( + f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1," + " one should make use of `greedy_search` instead." + ) + + if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): + raise ValueError( + "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be" + f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." + ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def make_constraint_states(self, n): + return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)] + + def check_completes_constraints(self, sequence): + new_state = self.make_constraint_states(1)[0] + new_state.reset(sequence) + return new_state.completed + + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + scores_for_all_vocab: torch.FloatTensor, + pad_token_id: Optional[Union[int, torch.Tensor]] = None, + eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None, + beam_indices: Optional[torch.LongTensor] = None, + decoder_prompt_len: Optional[int] = 0, + ) -> Tuple[torch.Tensor]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`): + Current scores of the top `2 * num_beams` non-finished beam hypotheses. + next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`): + `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses. + next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`): + Beam indices indicating to which beam hypothesis the `next_tokens` correspond. + scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`): + The scores of all tokens in the vocabulary for each of the beam hypotheses. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + beam_indices (`torch.LongTensor`, *optional*): + Beam indices indicating to which beam hypothesis each token correspond. + decoder_prompt_len (`int`, *optional*): + The length of prompt that is included in the input to decoder. + Return: + `UserDict`: A dictionary composed of the fields as defined above: + + - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of + all + non-finished beams. + + - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be + added + to the non-finished beam_hypotheses. + - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices + indicating to which beam the next tokens shall be added. + """ + + # add up to the length which the next_scores is calculated on (including decoder prompt) + cur_len = input_ids.shape[-1] + 1 + batch_size = len(self._beam_hyps) + if not (batch_size == (input_ids.shape[0] // self.group_size)): + if self.num_beam_groups > 1: + raise ValueError( + f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam " + f"size of {self.group_size} is expected by the beam scorer." + ) + else: + raise ValueError( + f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of " + f"{self.group_size} is expected by the beam scorer." + ) + + device = input_ids.device + + next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) + + if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + if self.num_beams < len(beam_hyp): + raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated") + if eos_token_id is None or pad_token_id is None: + raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined") + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence. + beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.group_size + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() in eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size + if is_beam_token_worse_than_top_num_beams: + continue + + completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist()) + if completes_constraint: + if beam_indices is not None: + beam_index = beam_indices[batch_beam_idx] + beam_index = beam_index + (batch_beam_idx,) + else: + beam_index = None + + beam_hyp.add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + beam_indices=beam_index, + generated_len=cur_len - decoder_prompt_len, + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to it. + if beam_idx == self.group_size: + break + + new_scores, new_tokens, new_indices = self.step_sentence_constraint( + batch_idx, + input_ids, + scores_for_all_vocab, + next_beam_scores[batch_idx], + next_beam_tokens[batch_idx], + next_beam_indices[batch_idx], + ) + + next_beam_scores[batch_idx] = new_scores + next_beam_tokens[batch_idx] = new_tokens + next_beam_indices[batch_idx] = new_indices + + if beam_idx < self.group_size: + raise ValueError( + f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:" + f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( + next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len + ) + + return UserDict( + { + "next_beam_scores": next_beam_scores.view(-1), + "next_beam_tokens": next_beam_tokens.view(-1), + "next_beam_indices": next_beam_indices.view(-1), + } + ) + + def step_sentence_constraint( + self, + batch_idx: int, + input_ids: torch.LongTensor, + vocab_scores: torch.FloatTensor, + sent_beam_scores: torch.FloatTensor, + sent_beam_tokens: torch.LongTensor, + sent_beam_indices: torch.LongTensor, + push_progress: bool = False, + ): + # sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam + # (candidate next tokens) + + # 1. Adding "advance_tokens" + # using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will + # advance us in fulfilling the constraints. + + # 2. Selecting best candidates such that we end up with highest probable candidates + # that fulfill our constraints. + + orig_len = sent_beam_indices.size(0) + device = sent_beam_indices.device + + # initialize states + topk_contraint_states = self.make_constraint_states(orig_len) + advance_constraint_states = self.make_constraint_states(orig_len) + + sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len + this_batch_input_ids = input_ids[sidx:eidx] + this_batch_token_scores = vocab_scores[sidx:eidx] + full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1) + + # need to make new hypothesis that advance the constraints + track_new = { + "new_seqs": full_hypotheses.tolist(), + "new_states": [], + "new_indices": [], + "new_tokens": [], + "new_scores": [], + } + for seq_idx, pre_seq in enumerate(this_batch_input_ids): + # pre_seq = ith sequence generated before this step. + + # input_ids -> (topk) generic beam search best model next tokens + # -> (advance) constraints forcing the next token + # either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of + # hypotheses. + + topk_state = topk_contraint_states[seq_idx] + topk_state.reset(full_hypotheses[seq_idx].cpu().tolist()) + + advance_state = advance_constraint_states[seq_idx] + advance_state.reset(pre_seq.cpu().tolist()) + + if not advance_state.completed: + advance_tokens = torch.LongTensor(advance_state.advance()).to(device) + for advance_token in advance_tokens: + # since adding each `advance_token` leads to a different hypothesis, create new state instance. + new_state = advance_state.copy(stateful=True) + new_state.add(advance_token.cpu().tolist()) + + advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist() + if advance_seq not in track_new["new_seqs"]: + # prevent duplicates, which are basically bound to happen in this process. + track_new["new_seqs"].append(advance_seq) + track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches + track_new["new_tokens"].append(advance_token) + track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token)) + track_new["new_states"].append(new_state) + elif push_progress: + # Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that + # actually fulfill our constraints. For example, let constraints == ["loves pies"] and + + # pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and" + + # Without this step, if `sent_beam_indices` is something like [1,1], then + # 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and + # 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is + # the else part of `if constraints_completed[seq_idx]`) + # 3. it ends up simply getting removed from consideration. + + # #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways, + # especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam + # search times, since completed sequences keep getting removed after all this effort for constrained + # generation. + + # Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply + # appending the next likely token in the vocabulary and adding it to the list of hypotheses. + + new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token + advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1) + + advance_state = advance_constraint_states[seq_idx] + + advance_seq = advance_seq.cpu().tolist() + + advance_state.reset(advance_seq) + if advance_seq not in track_new["new_seqs"]: + # but still don't want to have duplicates + track_new["new_seqs"].append(advance_seq) + track_new["new_indices"].append(seq_idx) + track_new["new_tokens"].append(new_token) + track_new["new_scores"].append(new_score) + track_new["new_states"].append(advance_state) + + if len(track_new["new_indices"]) > 0: + new_indices = torch.tensor(track_new["new_indices"]).to(device) + new_tokens = torch.stack(track_new["new_tokens"]).to(device) + new_scores = torch.stack(track_new["new_scores"]).to(device) + + all_states = topk_contraint_states + track_new["new_states"] + all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1) + all_scores = torch.cat((sent_beam_scores, new_scores), -1) + all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device) + + zipped = all_banks * 100 + all_scores + indices = zipped.sort(descending=True).indices + sorted_banks = all_banks[indices] + + # Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0} + + counter = -1 + cur_bank = sorted_banks[0] + increments = [] + for bank in sorted_banks: + if bank == cur_bank: + counter += 1 + else: + counter = 0 + cur_bank = bank + increments.append(counter) + rearrangers = torch.tensor(np.argsort(increments, kind="mergesort")) + + indices = indices[rearrangers][:orig_len] + + sent_beam_scores = all_scores[indices] + sent_beam_tokens = all_tokens[indices] + sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices] + + return sent_beam_scores, sent_beam_tokens, sent_beam_indices + + def finalize( + self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + final_beam_tokens: torch.LongTensor, + final_beam_indices: torch.LongTensor, + max_length: int, + pad_token_id: Optional[Union[int, torch.Tensor]] = None, + eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None, + beam_indices: Optional[torch.LongTensor] = None, + decoder_prompt_len: Optional[int] = 0, + ) -> Tuple[torch.LongTensor]: + batch_size = len(self._beam_hyps) + + if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id = torch.tensor(eos_token_id) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + continue + + # all open beam hypotheses are added to the beam hypothesis + # beam hypothesis class automatically keeps the best beams + + ids_collect = [] + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + + completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist()) + if completes_constraint: + beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None + generated_len = final_tokens.shape[-1] - decoder_prompt_len + beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len) + ids_collect.append(beam_id) + + # due to overly complex constraints or other factors, sometimes we can't gaurantee a successful + # generation. In these cases we simply return the highest scoring outputs. + if len(ids_collect) < self.num_beam_hyps_to_keep: + for beam_id in range(self.num_beams): + if beam_id not in ids_collect: + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + generated_len = final_tokens.shape[-1] - decoder_prompt_len + beam_hyp.add(final_tokens, final_score, generated_len=generated_len) + if len(ids_collect) >= self.num_beam_hyps_to_keep: + break + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + best_indices = [] + best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32) + + # retrieve best hypotheses + for i, beam_hyp in enumerate(self._beam_hyps): + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + best_hyp_tuple = sorted_hyps.pop() + best_score = best_hyp_tuple[0] + best_hyp = best_hyp_tuple[1] + best_index = best_hyp_tuple[2] + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + + # append to lists + best.append(best_hyp) + + # append indices to list + best_indices.append(best_index) + + best_scores[i * self.num_beam_hyps_to_keep + j] = best_score + + # prepare for adding eos + sent_lengths_max = sent_lengths.max().item() + 1 + + sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max + decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + + if len(best_indices) > 0 and best_indices[0] is not None: + indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + else: + indices = None + + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + if pad_token_id is None: + raise ValueError("`pad_token_id` has to be defined") + decoded.fill_(pad_token_id) + + if indices is not None: + indices.fill_(-1) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, (hypo, best_idx) in enumerate(zip(best, best_indices)): + decoded[i, : sent_lengths[i]] = hypo + + if indices is not None: + indices[i, : len(best_idx)] = torch.tensor(best_idx) + + if sent_lengths[i] < sent_max_len: + # inserting only the first eos_token_id + decoded[i, sent_lengths[i]] = eos_token_id[0] + + return UserDict( + { + "sequences": decoded, + "sequence_scores": best_scores, + "beam_indices": indices, + } + ) + + +class BeamHypotheses: + def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None): + """ + Initialize n-best list of hypotheses. + """ + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.max_length = max_length + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + if not isinstance(self.early_stopping, bool) and self.max_length is None: + raise ValueError( + "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the" + " BeamScorer class instance at initialization time." + ) + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add( + self, + hyp: torch.LongTensor, + sum_logprobs: float, + beam_indices: Optional[torch.LongTensor] = None, + generated_len: Optional[int] = None, + ): + """ + Add a new hypothesis to the list. + """ + if generated_len is not None: + score = sum_logprobs / (generated_len**self.length_penalty) + # This 'else' case exists for retrocompatibility + else: + score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) + + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp, beam_indices)) + if len(self) > self.num_beams: + sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) + del self.beams[sorted_next_scores[0][1]] + self.worst_score = sorted_next_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool: + """ + If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + + # `True`: stop as soon as at least `num_beams` hypotheses are finished + if self.early_stopping is True: + return True + # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate + # when `length_penalty` is positive. See the discussion below for more details. + # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 + elif self.early_stopping is False: + highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty + ret = self.worst_score >= highest_attainable_score + return ret + # `"never"`: compute the best possible score, depending on the signal of `length_penalty` + else: + # `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min + # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain + # its max this way + if self.length_penalty > 0.0: + if self.max_length <= decoder_prompt_len: + raise ValueError("max_length is not larger than decoder prompt length") + highest_attainable_score = ( + best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty + ) + # the opposite logic applies here (max `highest_attainable_score` from `cur_len`) + else: + highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty + ret = self.worst_score >= highest_attainable_score + return ret diff --git a/indextts/gpt/transformers_generation_utils.py b/indextts/gpt/transformers_generation_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..4e71b0bec8c70a193132e7919734db6b09bca137 --- /dev/null +++ b/indextts/gpt/transformers_generation_utils.py @@ -0,0 +1,4747 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import inspect +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch import nn +from torch.nn import functional as F + +from transformers.cache_utils import ( + Cache, + DynamicCache, + EncoderDecoderCache, + OffloadedCache, + QuantizedCacheConfig, + StaticCache, +) +from transformers.configuration_utils import PretrainedConfig +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.integrations.fsdp import is_fsdp_managed_module +from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput +from transformers.pytorch_utils import isin_mps_friendly +from transformers.tokenization_utils import ExtensionsTrie +from transformers.utils import ( + ModelOutput, + is_accelerate_available, + is_hqq_available, + is_optimum_quanto_available, + # is_quanto_available, + is_torchdynamo_compiling, + logging, +) +from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint +from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer +from transformers.generation.candidate_generator import ( + AssistedCandidateGenerator, + AssistedCandidateGeneratorDifferentTokenizers, + CandidateGenerator, + PromptLookupCandidateGenerator, + _crop_past_key_values, + _prepare_attention_mask, + _prepare_token_type_ids, +) +from transformers.generation.configuration_utils import ( + NEED_SETUP_CACHE_CLASSES_MAPPING, + QUANT_BACKEND_CLASSES_MAPPING, + GenerationConfig, + GenerationMode, +) +from transformers.generation.logits_process import ( + EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, + ExponentialDecayLengthPenalty, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, + LogitNormalization, + LogitsProcessorList, + MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, + MinPLogitsWarper, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + SequenceBiasLogitsProcessor, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, + UnbatchedClassifierFreeGuidanceLogitsProcessor, +) +from transformers.generation.stopping_criteria import ( + ConfidenceCriteria, + EosTokenCriteria, + MaxLengthCriteria, + MaxTimeCriteria, + StoppingCriteria, + StoppingCriteriaList, + StopStringCriteria, +) + + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from transformers.tokenization_utils_base import PreTrainedTokenizerBase + from transformers.generation.streamers import BaseStreamer + +logger = logging.get_logger(__name__) + +if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + +@dataclass +class GenerateDecoderOnlyOutput(ModelOutput): + """ + Outputs of decoder-only generation models, when using non-beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + + +@dataclass +class GenerateEncoderDecoderOutput(ModelOutput): + """ + Outputs of encoder-decoder generation models, when using non-beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor = None + scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + + +@dataclass +class GenerateBeamDecoderOnlyOutput(ModelOutput): + """ + Outputs of decoder-only generation models, when using beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): + Final beam scores of the generated `sequences`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None + beam_indices: Optional[torch.LongTensor] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + + +@dataclass +class GenerateBeamEncoderDecoderOutput(ModelOutput): + """ + Outputs of encoder-decoder generation models, when using beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): + Final beam scores of the generated `sequences`. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, + sequence_length, sequence_length)`. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, + sequence_length)`. + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor = None + sequences_scores: Optional[torch.FloatTensor] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None + beam_indices: Optional[torch.LongTensor] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + + +# TODO (joao): remove the equivalent classes and typing shortcuts below in v5 +# Equivalent classes (kept for retrocompatibility purposes) +GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput +ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput +SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput + +ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput +GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput +SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput + +BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput +BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput + +BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput +BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput + +GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] +BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] +BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] +ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] + +# Typing shortcuts +GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] +GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput] +GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] + + +class GenerationMixin: + """ + A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. + + The class exposes [`~generation.GenerationMixin.generate`], which can be used for: + - *greedy decoding* if `num_beams=1` and `do_sample=False` + - *contrastive search* if `penalty_alpha>0` and `top_k>1` + - *multinomial sampling* if `num_beams=1` and `do_sample=True` + - *beam-search decoding* if `num_beams>1` and `do_sample=False` + - *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True` + - *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1` + - *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None` + - *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()` + + To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). + """ + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + """ + Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or + slicing inputs given the existing cache. + + See the forward pass in the model documentation for expected arguments (different models might have different + requirements for e.g. `past_key_values`). This function should work as is for most LLMs. + """ + + # 1. Handle BC: + model_inputs = {} + # - some models don't have `Cache` support (which implies they don't expect `cache_position` in `forward`) + if self._supports_cache_class: + model_inputs["cache_position"] = cache_position + # - `cache_position` was not a mandatory input in `prepare_inputs_for_generation` for those models, and this + # function may be called outside of `generate`. Handle most use cases by creating `cache_position` on the fly + # (this alternative is not as robust as calling `generate` and letting it create `cache_position`) + elif cache_position is None: + past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + + # 2. Generic cache-dependent input preparation + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case + if past_key_values is not None: + model_inputs["past_key_values"] = past_key_values + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + # 3. Prepare base model inputs + input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if not self.config.is_encoder_decoder: + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs[input_ids_key] = None + model_inputs["inputs_embeds"] = inputs_embeds + else: + # `clone` calls in this function ensure a consistent stride. See #32227 + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) + model_inputs["inputs_embeds"] = None + else: + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) + + # 4. Create missing `position_ids` on the fly + if ( + attention_mask is not None + and kwargs.get("position_ids") is None + and "position_ids" in set(inspect.signature(self.forward).parameters.keys()) + ): + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below) + + # 5. Slice model inputs if it's an input that should have the same length as `input_ids` + for model_input_name in ["position_ids", "token_type_ids"]: + model_input = kwargs.get(model_input_name) + if model_input is not None: + if past_key_values: + model_input = model_input[:, -input_ids.shape[1] :] + model_input = model_input.clone(memory_format=torch.contiguous_format) + model_inputs[model_input_name] = model_input + + # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass) + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs[input_ids_key].shape + device = model_inputs[input_ids_key].device + + # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create + # the 4D causal mask exists, it should be present in the base model (XXXModel class). + base_model = getattr(self, self.base_model_prefix, None) + if base_model is None: + causal_mask_creation_function = getattr( + self, "_prepare_4d_causal_attention_mask_with_cache_position", None + ) + else: + causal_mask_creation_function = getattr( + base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None + ) + if causal_mask_creation_function is None: + logger.warning_once( + f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method " + "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're " + "writing code, see Llama for an example implementation. If you're a user, please report this " + "issue on GitHub." + ) + else: + attention_mask = causal_mask_creation_function( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + if attention_mask is not None: + model_inputs["attention_mask"] = attention_mask + + # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). + for key, value in kwargs.items(): + if key not in model_inputs: + model_inputs[key] = value + + # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples) + model_inputs.pop("labels", None) + return model_inputs + + def _prepare_model_inputs( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if ( + self.config.is_encoder_decoder + and hasattr(self, "encoder") + and self.encoder.main_input_name != self.main_input_name + ): + input_name = self.encoder.main_input_name + else: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " + f"Make sure to either pass {inputs} or {input_name}=..." + ) + elif inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. In the presence of `inputs_embeds` for text models: + # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model + # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with + # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) + # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and + # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states. + if input_name == "input_ids" and "inputs_embeds" in model_kwargs: + if not self.config.is_encoder_decoder: + has_inputs_embeds_forwarding = "inputs_embeds" in set( + inspect.signature(self.prepare_inputs_for_generation).parameters.keys() + ) + if not has_inputs_embeds_forwarding: + raise ValueError( + f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " + "doesn't have its forwarding implemented. See the GPT2 implementation for an example " + "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" + ) + # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of + # the attention mask) can rely on the actual model input. + model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs=model_kwargs + ) + else: + if inputs is not None: + raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") + inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" + + # 4. if `inputs` is still None, try to create `input_ids` from BOS token + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) + return inputs, input_name, model_kwargs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[torch.Tensor] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.last_hidden_state.size()[:-1] + return torch.ones(shape, dtype=torch.long, device=self.device) * -100 + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + + if "inputs_embeds" in model_kwargs: + return torch.ones((batch_size, 0), dtype=torch.long, device=self.device) + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_attention_mask_for_generation( + self, + inputs: torch.Tensor, + pad_token_id: Optional[torch.Tensor], + eos_token_id: Optional[torch.Tensor], + ) -> torch.LongTensor: + # No information for attention mask inference -> return default attention mask + default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + if pad_token_id is None: + return default_attention_mask + + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] + if not is_input_ids: + return default_attention_mask + + is_pad_token_in_inputs = (pad_token_id is not None) and ( + isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any() + ) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( + isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any() + ) + can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id + attention_mask_from_padding = inputs.ne(pad_token_id).long() + + attention_mask = ( + attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask + ) + return attention_mask + + def _prepare_encoder_decoder_kwargs_for_generation( + self, + inputs_tensor: torch.Tensor, + model_kwargs, + model_input_name: Optional[str], + generation_config: GenerationConfig, + ) -> Dict[str, Any]: + # 1. get encoder + encoder = self.get_encoder() + # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device + # as the inputs. + if hasattr(self, "hf_device_map"): + if hasattr(encoder, "_hf_hook"): + encoder._hf_hook.io_same_device = True + else: + add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True)) + + # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config. + irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not any(argument.startswith(p) for p in irrelevant_prefix) + } + encoder_signature = set(inspect.signature(encoder.forward).parameters) + encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature + if not encoder_accepts_wildcard: + encoder_kwargs = { + argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature + } + encoder_kwargs["output_attentions"] = generation_config.output_attentions + encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states + + # 3. make sure that encoder returns `ModelOutput` + model_input_name = model_input_name if model_input_name is not None else self.main_input_name + encoder_kwargs["return_dict"] = True + encoder_kwargs[model_input_name] = inputs_tensor + model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) # type: ignore + + return model_kwargs + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: torch.Tensor, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. `decoder_start_token_id` must have shape (batch_size, 1) + if device is None: + device = self.device + if decoder_start_token_id.ndim == 1: + if decoder_start_token_id.shape[0] != batch_size: + raise ValueError( + f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}" + ) + decoder_start_token_id = decoder_start_token_id.view(-1, 1) + else: + decoder_start_token_id = ( + torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id + ) + + # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_start_token_id + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the + # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic. + # See: https://github.com/huggingface/transformers/pull/31470 + elif "donut" in self.__class__.__name__.lower() or ( + self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower() + ): + pass + elif self.config.model_type in ["whisper"]: + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item(): + decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + + return decoder_input_ids, model_kwargs + + @staticmethod + def _expand_inputs_for_generation( + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" + # Do not call torch.repeat_interleave if expand_size is 1 because it clones + # the input tensor and thus requires more memory although no change is applied + if expand_size == 1: + return input_ids, model_kwargs + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + def _extract_past_from_model_output(self, outputs: ModelOutput): + past_key_values = None + cache_name = "past_key_values" + if "past_key_values" in outputs: + past_key_values = outputs.past_key_values + elif "mems" in outputs: + past_key_values = outputs.mems + elif "past_buckets_states" in outputs: + past_key_values = outputs.past_buckets_states + elif "cache_params" in outputs: + past_key_values = outputs.cache_params + cache_name = "cache_params" + + return cache_name, past_key_values + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + # update past_key_values keeping its naming used in model code + cache_name, cache = self._extract_past_from_model_output(outputs) + model_kwargs[cache_name] = cache + if getattr(outputs, "state", None) is not None: + model_kwargs["state"] = outputs.state + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + if not is_encoder_decoder: + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + else: + # update decoder attention mask + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + model_kwargs["decoder_attention_mask"] = torch.cat( + [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], + dim=-1, + ) + + if model_kwargs.get("use_cache", True): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + else: + past_positions = model_kwargs.pop("cache_position") + new_positions = torch.arange( + past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype + ).to(past_positions.device) + model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) + return model_kwargs + + def _reorder_cache(self, past_key_values, beam_idx): + raise NotImplementedError( + f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" + f" enable beam search for {self.__class__}" + ) + + def _get_candidate_generator( + self, + generation_config: GenerationConfig, + input_ids: torch.LongTensor, + inputs_tensor: torch.Tensor, + assistant_model: "PreTrainedModel", + logits_processor: LogitsProcessorList, + target_tokenizer: "PreTrainedTokenizerBase", + assistant_tokenizer: "PreTrainedTokenizerBase", + model_kwargs: Dict, + ) -> CandidateGenerator: + """ + Returns the candidate generator to be used in `assisted_generation` + """ + different_tokenizers = all(v is not None for v in (assistant_model, target_tokenizer, assistant_tokenizer)) + + if generation_config.prompt_lookup_num_tokens is not None: + candidate_generator = PromptLookupCandidateGenerator( + eos_token_id=generation_config._eos_token_tensor, + num_output_tokens=generation_config.prompt_lookup_num_tokens, + max_matching_ngram_size=generation_config.max_matching_ngram_size, + max_length=generation_config.max_length, + ) + elif different_tokenizers: + candidate_generator = AssistedCandidateGeneratorDifferentTokenizers( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + target_tokenizer=target_tokenizer, + assistant_tokenizer=assistant_tokenizer, + ) + else: + candidate_generator = AssistedCandidateGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + generation_config=generation_config, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + logits_processor=logits_processor, + ) + return candidate_generator + + def _get_logits_processor( + self, + generation_config: GenerationConfig, + input_ids_seq_length: int, + encoder_input_ids: torch.LongTensor, + prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], + logits_processor: Optional[LogitsProcessorList], + device: str = None, + model_kwargs: Optional[Dict[str, Any]] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] + instances used to modify the scores of the language model head. + """ + # instantiate processors list + processors = LogitsProcessorList() + + if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: + processors.append( + UnbatchedClassifierFreeGuidanceLogitsProcessor( + generation_config.guidance_scale, + self, + unconditional_ids=negative_prompt_ids, + unconditional_attention_mask=negative_prompt_attention_mask, + use_cache=generation_config.use_cache, + ) + ) + if generation_config.sequence_bias is not None: + processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias)) + + if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0: + processors.append( + HammingDiversityLogitsProcessor( + diversity_penalty=generation_config.diversity_penalty, + num_beams=generation_config.num_beams, + num_beam_groups=generation_config.num_beam_groups, + ) + ) + if ( + generation_config.encoder_repetition_penalty is not None + and generation_config.encoder_repetition_penalty != 1.0 + ): + if len(encoder_input_ids.shape) == 2: + processors.append( + EncoderRepetitionPenaltyLogitsProcessor( + penalty=generation_config.encoder_repetition_penalty, + encoder_input_ids=encoder_input_ids, + ) + ) + else: + warnings.warn( + "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to " + "`generate`, ignoring the argument.", + UserWarning, + ) + if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty)) + if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0: + processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size)) + if ( + generation_config.encoder_no_repeat_ngram_size is not None + and generation_config.encoder_no_repeat_ngram_size > 0 + ): + if len(encoder_input_ids.shape) == 2: + processors.append( + EncoderNoRepeatNGramLogitsProcessor( + generation_config.encoder_no_repeat_ngram_size, + encoder_input_ids, + ) + ) + else: + warnings.warn( + "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to " + "`generate`, ignoring the argument.", + UserWarning, + ) + if generation_config.bad_words_ids is not None: + processors.append( + NoBadWordsLogitsProcessor( + generation_config.bad_words_ids, + generation_config._eos_token_tensor, + ) + ) + if ( + generation_config.min_length is not None + and generation_config._eos_token_tensor is not None + and generation_config.min_length > 0 + ): + processors.append( + MinLengthLogitsProcessor( + generation_config.min_length, + generation_config._eos_token_tensor, + device=device, + ) + ) + if ( + generation_config.min_new_tokens is not None + and generation_config._eos_token_tensor is not None + and generation_config.min_new_tokens > 0 + ): + processors.append( + MinNewTokensLengthLogitsProcessor( + input_ids_seq_length, + generation_config.min_new_tokens, + generation_config._eos_token_tensor, + device=device, + ) + ) + if prefix_allowed_tokens_fn is not None: + processors.append( + PrefixConstrainedLogitsProcessor( + prefix_allowed_tokens_fn, + generation_config.num_beams // generation_config.num_beam_groups, + ) + ) + if generation_config.forced_bos_token_id is not None: + processors.append( + ForcedBOSTokenLogitsProcessor( + generation_config.forced_bos_token_id, + ) + ) + if generation_config.forced_eos_token_id is not None: + processors.append( + ForcedEOSTokenLogitsProcessor( + generation_config.max_length, + generation_config.forced_eos_token_id, + device=device, + ) + ) + if generation_config.remove_invalid_values is True: + processors.append(InfNanRemoveLogitsProcessor()) + if generation_config.exponential_decay_length_penalty is not None: + processors.append( + ExponentialDecayLengthPenalty( + generation_config.exponential_decay_length_penalty, + generation_config._eos_token_tensor, + input_ids_seq_length, + ) + ) + if generation_config.suppress_tokens is not None: + processors.append( + SuppressTokensLogitsProcessor( + generation_config.suppress_tokens, + device=device, + ) + ) + if generation_config.begin_suppress_tokens is not None: + begin_index = input_ids_seq_length + begin_index = ( + begin_index + if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) + else begin_index + 1 + ) + processors.append( + SuppressTokensAtBeginLogitsProcessor( + generation_config.begin_suppress_tokens, + begin_index, + device=device, + ) + ) + if generation_config.forced_decoder_ids is not None: + # TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT + raise ValueError( + "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument " + "in favour of `input_ids` or `decoder_input_ids` respectively.", + ) + if generation_config.watermarking_config is not None: + processors.append( + generation_config.watermarking_config.construct_processor(self.config.vocab_size, device) + ) + + # TODO (joao): find a strategy to specify the order of the processors + processors = self._merge_criteria_processor_list(processors, logits_processor) + + # Processors previously known as `LogitsWarpers`, only applied with sampling strategies + if generation_config.do_sample: + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) + if generation_config.num_beams > 1: + if isinstance(generation_config._eos_token_tensor, list): + min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 + elif isinstance(generation_config._eos_token_tensor, torch.Tensor): + min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if generation_config.temperature is not None and generation_config.temperature != 1.0: + processors.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + processors.append( + TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + processors.append( + TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.min_p is not None: + # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) + processors.append( + MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + processors.append( + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: + processors.append( + EpsilonLogitsWarper( + epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep + ) + ) + if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: + processors.append( + EtaLogitsWarper( + epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device + ) + ) + + # `LogitNormalization` should always be the last logit processor, when present + if generation_config.renormalize_logits is True: + processors.append(LogitNormalization()) + return processors + + def _get_stopping_criteria( + self, + generation_config: GenerationConfig, + stopping_criteria: Optional[StoppingCriteriaList], + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + **kwargs, + ) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if generation_config.max_length is not None: + max_position_embeddings = getattr(self.config, "max_position_embeddings", None) + criteria.append( + MaxLengthCriteria( + max_length=generation_config.max_length, + max_position_embeddings=max_position_embeddings, + ) + ) + if generation_config.max_time is not None: + criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) + if generation_config.stop_strings is not None: + if tokenizer is None: + raise ValueError( + "There are one or more stop strings, either in the arguments to `generate` or in the " + "model's generation config, but we could not locate a tokenizer. When generating with " + "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." + ) + criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) + if generation_config._eos_token_tensor is not None: + criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) + if ( + generation_config.is_assistant + and generation_config.assistant_confidence_threshold is not None + and generation_config.assistant_confidence_threshold > 0 + ): + criteria.append( + ConfidenceCriteria(assistant_confidence_threshold=generation_config.assistant_confidence_threshold) + ) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) + return criteria + + def _merge_criteria_processor_list( + self, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + if len(custom_list) == 0: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" + f" `.generate()`, but it has already been created with the values {default}. {default} has been" + " created by passing the corresponding arguments to generate or by the model's config default" + f" values. If you just want to change the default values of {object_type} consider passing" + f" them as arguments to `.generate()` instead of using a custom {object_type}." + ) + default_list.extend(custom_list) + return default_list + + def compute_transition_scores( + self, + sequences: torch.Tensor, + scores: Tuple[torch.Tensor], + beam_indices: Optional[torch.Tensor] = None, + normalize_logits: bool = False, + ) -> torch.Tensor: + """ + Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was + used). This is a convenient method to quicky obtain the scores of the selected tokens at generation time. + + Parameters: + sequences (`torch.LongTensor`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or + shorter if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)`): + Transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. + Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), + with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + beam_indices (`torch.LongTensor`, *optional*): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, sequence_length)`. Only required if a `num_beams>1` at + generate-time. + normalize_logits (`bool`, *optional*, defaults to `False`): + Whether to normalize the logits (which, for legacy reasons, may be unnormalized). + + Return: + `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing + the transition scores (logits) + + Examples: + + ```python + >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM + >>> import numpy as np + + >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer.pad_token_id = tokenizer.eos_token_id + >>> inputs = tokenizer(["Today is"], return_tensors="pt") + + >>> # Example 1: Print the scores for each token generated with Greedy Search + >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True) + >>> transition_scores = model.compute_transition_scores( + ... outputs.sequences, outputs.scores, normalize_logits=True + ... ) + >>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for + >>> # encoder-decoder models, like BART or T5. + >>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1] + >>> generated_tokens = outputs.sequences[:, input_length:] + >>> for tok, score in zip(generated_tokens[0], transition_scores[0]): + ... # | token | token string | log probability | probability + ... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}") + | 262 | the | -1.414 | 24.33% + | 1110 | day | -2.609 | 7.36% + | 618 | when | -2.010 | 13.40% + | 356 | we | -1.859 | 15.58% + | 460 | can | -2.508 | 8.14% + + >>> # Example 2: Reconstruct the sequence scores from Beam Search + >>> outputs = model.generate( + ... **inputs, + ... max_new_tokens=5, + ... num_beams=4, + ... num_return_sequences=4, + ... return_dict_in_generate=True, + ... output_scores=True, + ... ) + >>> transition_scores = model.compute_transition_scores( + ... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False + ... ) + >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores. + >>> # Tip 1: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the + >>> # use case, you might want to recompute it with `normalize_logits=True`. + >>> # Tip 2: the output length does NOT include the input length + >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1) + >>> length_penalty = model.generation_config.length_penalty + >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty) + >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores)) + True + ```""" + # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent + # to a beam search approach were the first (and only) beam is always selected + if beam_indices is None: + beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) + beam_indices = beam_indices.expand(-1, len(scores)) + + # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being + # seq_len - input_length + scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) + + # 3. Optionally normalize the logits (across the vocab dimension) + if normalize_logits: + scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1]) + scores = torch.nn.functional.log_softmax(scores, dim=1) + scores = scores.reshape(-1, scores.shape[-1]) + + # 4. cut beam_indices to longest beam length + beam_indices_mask = beam_indices < 0 + max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() + beam_indices = beam_indices.clone()[:, :max_beam_length] + beam_indices_mask = beam_indices_mask[:, :max_beam_length] + + # 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards + beam_indices[beam_indices_mask] = 0 + + # 6. multiply beam_indices with vocab size to gather correctly from scores + beam_sequence_indices = beam_indices * self.config.vocab_size + + # 7. Define which indices contributed to scores + cut_idx = sequences.shape[-1] - max_beam_length + indices = sequences[:, cut_idx:] + beam_sequence_indices + + # 8. Compute scores + transition_scores = scores.gather(0, indices) + + # 9. Mask out transition_scores of beams that stopped early + transition_scores[beam_indices_mask] = 0 + + return transition_scores + + def _validate_model_class(self): + """ + Confirms that the model class is compatible with generation. If not, raises an exception that points to the + right class to use. + """ + # TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from + # `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can + # safely call `GenerationMixin.generate` + if not is_torchdynamo_compiling() and not self.can_generate(): + terminations_with_generation_support = [ + "ForCausalLM", + "ForConditionalGeneration", + "ForSpeechSeq2Seq", + "ForVision2Seq", + ] + raise TypeError( + f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as " + "it doesn't have a language model head. Classes that support generation often end in one of these " + f"names: {terminations_with_generation_support}." + ) + + def _validate_assistant(self, assistant_model, tokenizer, assistant_tokenizer): + if assistant_model is None: + return + + if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder: + attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"] + attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check] + are_equal = all( + getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check + ) + if not are_equal: + raise ValueError( + "The main model and the assistant don't have compatible encoder-dependent input shapes. " + "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper." + ) + + doc_reference = ( + "(see https://huggingface.co/docs/transformers/en/generation_strategies#universal-assisted-decoding)" + ) + if self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size: + if assistant_tokenizer is not None: + raise ValueError( + f"`assistant_tokenizer` is not required when the main and assistant models use the same tokenizer. Please omit `assistant_tokenizer` from `generate()` {doc_reference}." + ) + else: + if tokenizer is None or assistant_tokenizer is None: + raise ValueError( + f"The main and assistant moedels have different tokenizers. Please provide `tokenizer` and `assistant_tokenizer` to `generate()` {doc_reference}." + ) + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + # If a `Cache` instance is passed, checks whether the model is compatible with it + if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: + raise ValueError( + f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " + "check the model documentation for supported cache formats." + ) + + # Excludes arguments that are handled before calling any model function + if self.config.is_encoder_decoder: + for key in ["decoder_input_ids"]: + model_kwargs.pop(key, None) + + unused_model_args = [] + model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) + # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If + # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) + if "kwargs" in model_args or "model_kwargs" in model_args: + model_args |= set(inspect.signature(self.forward).parameters) + + # Encoder-Decoder models may also need Encoder arguments from `model_kwargs` + if self.config.is_encoder_decoder: + base_model = getattr(self, self.base_model_prefix, None) + + # allow encoder kwargs + encoder = getattr(self, "encoder", None) + # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`. + # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder` + # TODO: A better way to handle this. + if encoder is None and base_model is not None: + encoder = getattr(base_model, "encoder", None) + + if encoder is not None: + encoder_model_args = set(inspect.signature(encoder.forward).parameters) + model_args |= encoder_model_args + + # allow decoder kwargs + decoder = getattr(self, "decoder", None) + if decoder is None and base_model is not None: + decoder = getattr(base_model, "decoder", None) + + if decoder is not None: + decoder_model_args = set(inspect.signature(decoder.forward).parameters) + model_args |= {f"decoder_{x}" for x in decoder_model_args} + + # allow assistant_encoder_outputs to be passed if we're doing assisted generating + if "assistant_encoder_outputs" in model_kwargs: + model_args |= {"assistant_encoder_outputs"} + + for key, value in model_kwargs.items(): + if value is not None and key not in model_args: + unused_model_args.append(key) + + if unused_model_args: + raise ValueError( + f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" + " generate arguments will also show up in this list)" + ) + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + """Performs validation related to the resulting generated length""" + + # Can't throw warnings/exceptions during compilation + if is_torchdynamo_compiling(): + return + + # 1. Max length warnings related to poor parameterization + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: + # 20 is the default max_length of the generation config + warnings.warn( + f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " + "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " + "generation.", + UserWarning, + ) + if input_ids_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + raise ValueError( + f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_length` or, better yet, setting `max_new_tokens`." + ) + + # 2. Min length warnings due to unfeasible parameter combinations + min_length_error_suffix = ( + " Generation will stop at the defined maximum length. You should decrease the minimum length and/or " + "increase the maximum length." + ) + if has_default_max_length: + min_length_error_suffix += ( + f" Note that `max_length` is set to {generation_config.max_length}, its default value." + ) + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + warnings.warn( + f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than" + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + UserWarning, + ) + if generation_config.min_new_tokens is not None: + min_length = generation_config.min_new_tokens + input_ids_length + if min_length > generation_config.max_length: + warnings.warn( + f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " + f"added to the prompt length ({input_ids_length}), is larger than" + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + UserWarning, + ) + + def _prepare_generated_length( + self, + generation_config, + has_default_max_length, + has_default_min_length, + model_input_name, + input_ids_length, + inputs_tensor, + ): + """Prepared max and min length in generation configs to avoid clashes between similar attributes""" + + if generation_config.max_new_tokens is not None: + if not has_default_max_length and generation_config.max_length is not None: + logger.warning( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.max_length = generation_config.max_new_tokens + input_ids_length + + # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length + # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length`` + elif ( + model_input_name == "inputs_embeds" + and input_ids_length != inputs_tensor.shape[1] + and not self.config.is_encoder_decoder + ): + generation_config.max_length -= inputs_tensor.shape[1] + + # same for min length + if generation_config.min_new_tokens is not None: + if not has_default_min_length: + logger.warning( + f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(=" + f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. " + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + generation_config.min_length = generation_config.min_new_tokens + input_ids_length + + elif ( + model_input_name == "inputs_embeds" + and input_ids_length != inputs_tensor.shape[1] + and not self.config.is_encoder_decoder + ): + generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0) + + return generation_config + + def _prepare_generation_config( + self, generation_config: Optional[GenerationConfig], **kwargs: Dict + ) -> Tuple[GenerationConfig, Dict]: + """ + Prepares the base generation config, then applies any generation configuration options from kwargs. This + function handles retrocompatibility with respect to configuration files. + """ + # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400) + # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with + # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`. + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + using_model_generation_config = False + if generation_config is None: + # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, + # the following conditions must be met + # 1) the generation config must have been created from the model config (`_from_model_config` field); + # 2) the generation config must have seen no modification since its creation (the hash is the same); + # 3) there are non-default generation parameters in the model config. + # 4) the user must have set new generation parameters in the model config. + # NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation. + if ( + not is_torchdynamo_compiling() + and self.generation_config._from_model_config # 1) + and self.generation_config._original_object_hash == hash(self.generation_config) # 2) + and len(self.config._get_non_default_generation_parameters()) > 0 # 3) + ): + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: # 4) + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed in v5." + " Please use and modify the model generation configuration (see" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )", + UserWarning, + ) + self.generation_config = new_generation_config + + generation_config = self.generation_config + using_model_generation_config = True + + # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` + # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an + # exception will be raised in `_validate_model_kwargs` + if not is_torchdynamo_compiling(): + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model + if not using_model_generation_config: + if generation_config.bos_token_id is None: + generation_config.bos_token_id = self.generation_config.bos_token_id + if generation_config.eos_token_id is None: + generation_config.eos_token_id = self.generation_config.eos_token_id + if generation_config.pad_token_id is None: + generation_config.pad_token_id = self.generation_config.pad_token_id + if generation_config.decoder_start_token_id is None: + generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id + else: + model_kwargs = kwargs + + return generation_config, model_kwargs + + def _get_initial_cache_position(self, input_ids, model_kwargs): + """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" + # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` + if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: + cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder: + cache_position = ( + torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + ) + else: + cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 + + past_length = 0 + if model_kwargs.get("past_key_values") is not None: + cache = model_kwargs["past_key_values"] + past_length = 0 + if not isinstance(cache, Cache): + past_length = cache[0][0].shape[2] + elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: + past_length = cache.get_seq_length() + + # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, + # end-to-end compilation will yield bad results because `cache_position` will be incorrect. + if not is_torchdynamo_compiling(): + cache_position = cache_position[past_length:] + + model_kwargs["cache_position"] = cache_position + return model_kwargs + + def _get_cache( + self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs + ) -> Cache: + """ + Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a + new `generate` call requires a larger cache or uses a different batch size. + + Returns the resulting cache object. + """ + cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + if hasattr(self, "_cache"): + cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache + + if cache_implementation == "sliding_window": + max_cache_len = min(self.config.sliding_window, max_cache_len) + + need_new_cache = ( + not hasattr(self, "_cache") + or (not isinstance(cache_to_check, cache_cls)) + or cache_to_check.batch_size != batch_size + ) + if cache_implementation != "mamba": + need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len + + if requires_cross_attention_cache and hasattr(self, "_cache"): + need_new_cache = ( + need_new_cache + or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1] + ) + + if need_new_cache: + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + if not is_torchdynamo_compiling(): + cache_dtype = self.dtype + else: + # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`. + # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative + # models. May cause trobles with non-text modalities. + cache_dtype = self.get_output_embeddings().weight.dtype + + def get_layer_device_map(execution_device_map: Optional[dict] = None): + if execution_device_map is None: + return None + elif len(execution_device_map) == 1 and "" in execution_device_map: + return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)} + layer_device_map = {} + for layer in execution_device_map: + for idx in range(self.config.num_hidden_layers): + if f".{idx}." in f"{layer}.": + layer_device_map[idx] = execution_device_map[layer] + break + for idx in range(self.config.num_hidden_layers): + if idx not in layer_device_map: + raise RuntimeError(f"layer {idx} has not been mapped to a device.") + return layer_device_map + + execution_device_map = None + # Taken from dispatch_model from accelerate. + # This is needed here if we don't want to make changes in accelerate in order to save execution_device + # For offloaded case, we need to get the execution device, not just the device where it is offloaded + if hasattr(self, "hf_device_map"): + main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0] + execution_device_map = { + name: main_device if device in ["cpu", "disk"] else device + for name, device in self.hf_device_map.items() + } + layer_device_map = get_layer_device_map(execution_device_map) + + cache_kwargs = { + "config": self.config.get_text_config(), + "batch_size": batch_size, + "max_cache_len": max_cache_len, + "device": device, + "dtype": cache_dtype, + "layer_device_map": layer_device_map, + } + self._cache = cache_cls(**cache_kwargs) + if requires_cross_attention_cache: + encoder_kwargs = cache_kwargs.copy() + encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] + self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs)) + else: + self._cache.reset() + return self._cache + + def _supports_default_dynamic_cache(self) -> bool: + """ + Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. + This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which + uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in + order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed + for `HybridMambaAttentionDynamicCache`). + """ + return ( + self._supports_cache_class + and "jamba" not in self.__class__.__name__.lower() + and "zamba" not in self.__class__.__name__.lower() + ) + + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: Dict, + assistant_model: "PreTrainedModel", + batch_size: int, + max_cache_length: int, + device: torch.device, + ) -> bool: + """ + Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is + instantiated, writes it to `model_kwargs`, under the name expected by the model. + """ + + cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + # Quick escape route 1: if the user specifies a cache, we only need to: + # a) check for conflicting `generate` arguments + # b) convert to the new cache format (if the user passes a legacy cache and model supports it) + user_defined_cache = model_kwargs.get(cache_name) + if user_defined_cache is not None: + if generation_config.cache_implementation is not None: + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache(): + model_kwargs[cache_name] = ( + DynamicCache.from_legacy_cache(user_defined_cache) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(user_defined_cache) + ) + return + + # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in + # `generation_config.validate()`) + if generation_config.use_cache is False: + return + + # Quick escape route 3: model that only supports legacy caches = nothing to prepare + if not self._supports_default_dynamic_cache(): + if generation_config.cache_implementation is not None: + warnings.warn( + "This model does not support `Cache` instances, it only supports the legacy cache format (tuple " + f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be " + "ignored.", + UserWarning, + ) + return + + # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation` + + # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, + # which is only supported in dynamic caches atm + if assistant_model is not None and generation_config.cache_implementation is not None: + logger.warning_once( + "An assistant model is provided, using a dynamic cache instead of a cache of type=" + f"'{generation_config.cache_implementation}'." + ) + generation_config.cache_implementation = None + + if generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs[cache_name] = self._get_cache( + cache_implementation=generation_config.cache_implementation, + batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size, + max_cache_len=max_cache_length, + device=device, + model_kwargs=model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue and tag @zucchini-nlp." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + # if cache_config.backend == "quanto" and not (is_optimum_quanto_available() or is_quanto_available()): + if cache_config.backend == "quanto" and not is_optimum_quanto_available(): + raise ImportError( + "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. " + "Please install it via with `pip install optimum-quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs[cache_name] = cache_class(cache_config) + elif generation_config.cache_implementation == "offloaded": + model_kwargs[cache_name] = OffloadedCache() + + # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + else: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) + + def _supports_num_logits_to_keep(self) -> bool: + """ + Return True if the current model supports the keyword argument `num_logits_to_keep` in forward() + to save memory. Checking it in this way allows to avoid using a new model attribute. + """ + return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys()) + + def _prepare_special_tokens( + self, + generation_config: GenerationConfig, + kwargs_has_attention_mask: Optional[bool] = None, + device: Optional[Union[torch.device, str]] = None, + ): + """ + Prepares the special tokens for generation, overwriting the generation config with their processed versions + converted to tensor. + + Note that `generation_config` is changed in place and stops being serializable after this method is called. + That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the + function). However, if called outside `generate`, consider creating a copy of `generation_config` first. + """ + + # Convert special tokens to tensors + def _tensor_or_none(token, device=None): + if token is None: + return token + + device = device if device is not None else self.device + if isinstance(token, torch.Tensor): + return token.to(device) + return torch.tensor(token, device=device, dtype=torch.long) + + bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) + eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) + pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) + decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device) + + # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892) + if self.config.is_encoder_decoder: + decoder_start_token_tensor = ( + decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor + ) + + # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). + if eos_token_tensor is not None and eos_token_tensor.ndim == 0: + eos_token_tensor = eos_token_tensor.unsqueeze(0) + + # Set pad token if unset (and there are conditions to do so) + if pad_token_tensor is None and eos_token_tensor is not None: + if not is_torchdynamo_compiling(): + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") + pad_token_tensor = eos_token_tensor[0] + + # Sanity checks/warnings + if self.config.is_encoder_decoder and decoder_start_token_tensor is None: + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow + if ( + eos_token_tensor is not None + and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any() + ): + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning_once( + "The attention mask is not set and cannot be inferred from input because pad token is same as " + "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's " + "`attention_mask` to obtain reliable results." + ) + if eos_token_tensor is not None and ( + torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any() + ): + logger.warning( + f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation " + "will not stop until the maximum length is reached. Depending on other flags, it may even crash." + ) + + # Update generation config with the updated special tokens tensors + # NOTE: this must be written into a different attribute name than the one holding the original special tokens + # (in their non-tensor form), in order to enable end-to-end compilation. See + # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations + generation_config._bos_token_tensor = bos_token_tensor + generation_config._eos_token_tensor = eos_token_tensor + generation_config._pad_token_tensor = pad_token_tensor + generation_config._decoder_start_token_tensor = decoder_start_token_tensor + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](../generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config ([`~generation.GenerationConfig`], *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which has the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complements the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. If your stopping criteria depends on the `scores` input, make + sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is + intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*): + Whether to continue running the while loop until max_length. Unless overridden, this flag will be set + to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid + deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`. + assistant_model (`PreTrainedModel`, *optional*): + An assistant model that can be used to accelerate generation. The assistant model must have the exact + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model + is much faster than running generation with the model you're calling generate from. As such, the + assistant model should be much smaller. + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + The negative prompt needed for some processors such as CFG. The batch size must match the input batch + size. This is an experimental feature, subject to breaking API changes in future versions. + negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Attention_mask for `negative_prompt_ids`. + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + """ + + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria + assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation + + generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) + self._validate_model_kwargs(model_kwargs.copy()) + self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer) + + # 2. Set generation parameters if not already defined + if synced_gpus is None: + synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 + + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None + + # 3. Define model inputs + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + + # decoder-only models must use left-padding for batched generation. + if not self.config.is_encoder_decoder and not is_torchdynamo_compiling(): + # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` + # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. + if ( + generation_config._pad_token_tensor is not None + and batch_size > 1 + and len(inputs_tensor.shape) == 2 + and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + # 4. Define other model kwargs + # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are + # generating the first new token or not, and we only want to use the embeddings for the first new token) + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + generation_config.use_cache = True + + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor + ) + elif kwargs_has_attention_mask: + # TODO (joao): generalize this check with other types of inputs + if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2: + raise ValueError("`attention_mask` passed to `generate` must be 2D.") + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name, generation_config + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, + model_kwargs=model_kwargs, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + device=inputs_tensor.device, + ) + else: + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + + if generation_config.token_healing: + input_ids = self.heal_tokens(input_ids, tokenizer) + + if streamer is not None: + streamer.put(input_ids.cpu()) + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + # If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole + # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding + # dynamically overrides this value as it can need more than the last token logits + if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs: + model_kwargs["num_logits_to_keep"] = 1 + + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. Prepare the cache. + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. + # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format) + cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + user_defined_cache = model_kwargs.get(cache_name) + max_cache_length = generation_config.max_length + if ( + inputs_tensor.shape[1] != input_ids_length + and model_input_name == "inputs_embeds" + and not self.config.is_encoder_decoder + ): + max_cache_length += inputs_tensor.shape[1] + self._prepare_cache_for_generation( + generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device + ) + + # 8. determine generation mode + generation_mode = generation_config.get_generation_mode(assistant_model) + + if streamer is not None and (generation_config.num_beams > 1): + raise ValueError( + "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." + ) + + if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + + # 9. prepare logits processors and stopping criteria + prepared_logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + device=inputs_tensor.device, + model_kwargs=model_kwargs, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + prepared_stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs + ) + + # Set model_kwargs `use_cache` so we can use it later in forward runs + model_kwargs["use_cache"] = generation_config.use_cache + + # 10. go into different generation modes + if generation_mode == GenerationMode.ASSISTED_GENERATION: + if generation_config.num_return_sequences > 1: + raise ValueError( + "num_return_sequences has to be 1 when doing assisted generate, " + f"but is {generation_config.num_return_sequences}." + ) + if batch_size > 1: + raise ValueError("assisted generate is only supported for batch_size = 1") + if not model_kwargs["use_cache"]: + raise ValueError("assisted generate requires `use_cache=True`") + if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]: + raise ValueError("assisted generate is not supported with Static cache classes`") + if self._is_stateful: + # In assisted generation we need the ability to confirm whether the model would pick certain tokens, + # which is not possible with stateful models (they can't reset to a previous subset of generated text) + raise ValueError( + f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" + ) + + # 11. Get the candidate generator, given the parameterization + candidate_generator = self._get_candidate_generator( + generation_config=generation_config, + input_ids=input_ids, + inputs_tensor=inputs_tensor, + assistant_model=assistant_model, + logits_processor=logits_processor, + target_tokenizer=tokenizer, + assistant_tokenizer=assistant_tokenizer, + model_kwargs=model_kwargs, + ) + + # 12. run assisted generate + result = self._assisted_decoding( + input_ids, + candidate_generator=candidate_generator, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + elif generation_mode == GenerationMode.DOLA_GENERATION: + if self._is_stateful: + # DoLa decoding was not designed for stateful models, and would require some changes + raise ValueError( + f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}" + ) + result = self._dola_decoding( + input_ids, + dola_layers=generation_config.dola_layers, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: + if not model_kwargs["use_cache"]: + raise ValueError("Contrastive search requires `use_cache=True`") + if self._is_stateful: + # Just like assisted generation, we need to be able to rollback to a previous state (see comment above) + raise ValueError( + f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" + ) + + result = self._contrastive_search( + input_ids, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + # 11. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + result = self._sample( + input_ids, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, + ) + + elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + max_length=generation_config.max_length, + ) + + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run beam sample + result = self._beam_search( + input_ids, + beam_scorer, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH: + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + max_length=generation_config.max_length, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + result = self._group_beam_search( + input_ids, + beam_scorer, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH: + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]` " + f"of positive integers, but is {generation_config.force_words_ids}." + ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any(not isinstance(token_ids, list) for token_ids in word_ids): + typeerror() + if any( + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + max_length=generation_config.max_length, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + result = self._constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=prepared_logits_processor, + stopping_criteria=prepared_stopping_criteria, + generation_config=generation_config, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + # Convert to legacy cache format if requested + if ( + generation_config.return_legacy_cache is not False # Should check for `True` after v4.47 + and not is_torchdynamo_compiling() + and hasattr(result, "past_key_values") + and hasattr(result.past_key_values, "to_legacy_cache") + and result.past_key_values.to_legacy_cache is not None + ): + # handle BC (convert by default if he user hasn't passed a cache AND the cache is of the default type) + should_convert_cache = generation_config.return_legacy_cache + is_user_defined_cache = user_defined_cache is not None + is_default_cache_type = ( + type(result.past_key_values) == DynamicCache # noqa E721 + or ( + isinstance(result.past_key_values, EncoderDecoderCache) + and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 + and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721 + ) + ) + if not is_user_defined_cache and is_default_cache_type: + logger.warning_once( + "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` " + "instance instead by default (as opposed to the legacy tuple of tuples format). If you want to " + "keep returning the legacy format, please set `return_legacy_cache=True`." + ) + should_convert_cache = True + if should_convert_cache: + result.past_key_values = result.past_key_values.to_legacy_cache() + return result + + def _has_unfinished_sequences( + self, + this_peer_finished: bool, + synced_gpus: bool, + device: torch.device, + cur_len: Optional[int] = None, + max_length: Optional[int] = None, + ) -> bool: + """ + Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is + fed through `this_peer_finished`. ZeRO stage 3-friendly. + """ + # torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile, + # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria) + # TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html) + if is_torchdynamo_compiling(): + return cur_len < max_length + else: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + return False + elif this_peer_finished: + return False + return True + + def heal_tokens( + self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None + ) -> torch.LongTensor: + r""" + Generates sequences of token ids for models with a language modeling head. + Parameters: + input_ids (`torch.LongTensor`): The sequence used as a prompt for the generation. + tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids. + Return: + `torch.LongTensor` where each sequence has its tail token replaced with its appropriate extension. + """ + if tokenizer is None: + raise ValueError( + " When generating with token healing, you must pass the model's tokenizer to the `tokenizer` " + "argument of `generate`." + ) + + bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id + vocab_trie = ExtensionsTrie(tokenizer.get_vocab()) + generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id) + + # assumption: leading/trailing whitespace is not meaningful, so the prompts are + # stripped before re-tokenizing to desensitize generation to whitespace artefacts + prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)] + input_ids = tokenizer( + prompts, + return_tensors="pt", + padding=True, + ).input_ids.to(input_ids.device) + + # replace bos with pad to not condition healing on it + input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids) + + """ + the latter code assumes the input_ids is not empty, + input_id has to be checked if contains elements + """ + if input_ids.numel() == 0: + return input_ids + + tail_ids = input_ids[:, -1].tolist() + + space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0] + # tail tokens are used for a prefix search, thus, whitespaces are replaced with + # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace + tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids) + + for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)): + batch_ids = input_ids[batch_idx] + if torch.all(batch_ids == pad_token_id).item(): + continue # skip empty sequences (all pad ids) + + # apply bias for alternatives (extensions) to the tail token + """ + seq_bias key has to be tuple with int so have to use + tokenizer function to convert str to int + """ + seq_bias = { + (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok) + } + + if len(seq_bias) == 1: + continue # skip if there are no token alternatives to heal with + + # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https' + seq_bias[(tail_id,)] += 1.0 + generation_config.update(sequence_bias=seq_bias) + + trimmed_ids = batch_ids[:-1] + + """ + the latter code assumes trimmed_ids is not empty + so have to check the its element count + """ + if trimmed_ids.numel() == 0: + continue + + # if the prompt is a single (non-pad) token, regenerate from bos + if len(batch_ids[batch_ids != pad_token_id]) == 1: + trimmed_ids[-1] = bos_token_id + + input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config) + + return input_ids + + def _dola_decoding( + self, + input_ids: torch.LongTensor, + dola_layers: Union[str, List[int]], + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: "BaseStreamer", + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **dola decoding** and can be + used for decoder-only text models. + The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language + Models" (https://arxiv.org/abs/2309.03883) in ICLR 2024. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + dola_layers (`Union[str, List[int]]`): + The candidate layers used in contrasting layers of DoLa. It can be either 1) 'low' or 'high', which + means the lower part or higher part of the model layers, respectively, or 2) a list of layer indices + to be used for candidate layers. The 0-th layer is the word embedding layer of the model. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + + if self.config.is_encoder_decoder: + raise ValueError("DoLa decoding is only available for decoder-only models.") + # init values + + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + this_peer_finished = False + + # prepare layers for DoLa decoding + final_layer = self.config.get_text_config().num_hidden_layers + # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, + # as the early exit from word embeddings will become identity function + # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th + # layer otherwise. Notice that DoLa does not help shallow models much. + if not self.config.tie_word_embeddings: + start_layer = 0 + elif final_layer > 2: + start_layer = 2 + elif final_layer == 2: + start_layer = 1 + else: + start_layer = 0 + + # For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` + # are used for `'low'` and `'high'` layers, respectively. + # For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for + # `'low'` and `'high'` layers, respectively. + if isinstance(dola_layers, str) and dola_layers == "low": + if start_layer == final_layer // 2: + candidate_premature_layers = [start_layer] + else: + candidate_premature_layers = ( + list(range(start_layer, final_layer // 2, 2)) + if final_layer <= 40 + else list(range(start_layer, 20, 2)) + ) + elif isinstance(dola_layers, str) and dola_layers == "high": + candidate_premature_layers = ( + list(range(final_layer // 2, final_layer, 2)) + if final_layer <= 40 + else list(range(final_layer - 20, final_layer, 2)) + ) + # Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers. + elif isinstance(dola_layers, list): + candidate_premature_layers = [i for i in dola_layers if i < final_layer] + else: + raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.") + + lm_head = self.get_output_embeddings() + if lm_head is None: + raise ValueError("DoLa is not supported for models that don't have output embeddings.") + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=True, + ) + + # .float() is needed to retain precision for later logits manipulations + final_layer_next_token_logits = outputs.logits[:, -1, :].detach().clone().float() + final_logits = outputs.logits[:, -1, :].float() + candidate_premature_logits = {} + for candidate_premature_layer in candidate_premature_layers: + candidate_premature_logits[candidate_premature_layer] = lm_head( + outputs.hidden_states[candidate_premature_layer][:, -1, :] + ).to(final_logits.device) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + continue + + next_token_logits = _dola_select_contrast( + candidate_premature_layers, candidate_premature_logits, final_logits + ) + next_token_logits = next_token_logits.to(input_ids.device) + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (final_layer_next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + if do_sample: # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: # argmax + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + + # stop when each sentence is finished + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + @torch.no_grad() + def _contrastive_search( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **contrastive search** and can + be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] + or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + top_k = generation_config.top_k + penalty_alpha = generation_config.penalty_alpha + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + sequential = generation_config.low_memory + + # init attention / hidden states / scores tuples + raw_logits = () if (return_dict_in_generate and output_logits) else None + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + # Create cosine_matrix_mask based on the attention_mask + cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long) + if self.config.is_encoder_decoder: + if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None: + cosine_matrix_mask = model_kwargs["decoder_attention_mask"] + else: + cosine_matrix_mask = model_kwargs["attention_mask"] + cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0) + + this_peer_finished = False + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; + # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step + if model_kwargs.get("past_key_values") is None or ( + isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) + and model_kwargs["past_key_values"].get_seq_length() == 0 + ): + # prepare inputs + model_kwargs["use_cache"] = True + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save + # the `encoder_outputs` + outputs = self( + **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + ) + + # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with + # previous tokens) + if self.config.is_encoder_decoder: + last_hidden_states = outputs.decoder_hidden_states[-1] + else: + last_hidden_states = outputs.hidden_states[-1] + + # next logit for contrastive search to select top-k candidate tokens + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration + # (the clone itself is always small) + # .float() is needed to retain precision for later logits manipulations + logit_for_next_step = outputs.logits[:, -1, :].clone().float() + logit_for_next_step = logit_for_next_step.to(input_ids.device) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + + if not sequential: + # Expands model inputs top_k times, for batched forward passes (akin to beam search). + _, model_kwargs = self._expand_inputs_for_generation( + expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + + past_key_values = model_kwargs.get("past_key_values") + if past_key_values is None: + raise ValueError( + f"{self.__class__.__name__} does not support caching and therefore **can't** be used " + "for contrastive search." + ) + elif ( + not isinstance(past_key_values[0], (tuple, torch.Tensor)) + or past_key_values[0][0].shape[0] != batch_size + ): + raise ValueError( + f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " + "used for contrastive search without further modifications." + ) + + # contrastive_search main logic start: + # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by + # degeneration penalty + processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step) + next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1) + + top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_logits: + raw_logits += (logit_for_next_step,) + if output_scores: + scores += (processed_logit_for_next_step,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # This is needed to properly delete outputs.logits which may be very large for this first iteration + # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward() + del outputs + + if not sequential: + # Replicates the new past_key_values to match the `top_k` candidates + past = model_kwargs["past_key_values"] + # If it is a static cache, modify it in-place layer after layer to save memory + if isinstance(past, DynamicCache) or ( + isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache) + ): + past.batch_repeat_interleave(top_k) + else: + new_key_values = [] + for layer in past: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item.repeat_interleave(top_k, dim=0)) + new_key_values.append(tuple(items)) + + past = tuple(new_key_values) + + model_kwargs["past_key_values"] = past + + if sequential: + all_outputs = [] + for i in range(top_k): + # compute the candidate tokens by the language model and collect their hidden_states + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + if isinstance(outputs["past_key_values"], DynamicCache) or ( + isinstance(outputs["past_key_values"], EncoderDecoderCache) + and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache) + ): + # Remove past K-V from output since we don't need to stack later + outputs["past_key_values"] = None + # Remove last token from past K-V since we don't want to append it at this point + model_kwargs["past_key_values"].crop(-1) + + all_outputs.append(outputs) + outputs = stack_model_outputs(all_outputs, self.config.get_text_config()) + + else: + # compute the candidate tokens by the language model and collect their hidden_states + # assembles top_k_ids into batch of size k + next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs) + + outputs = self( + **next_model_inputs, + return_dict=True, + output_hidden_states=True, + output_attentions=output_attentions, + ) + + # This is essential to avoid having a last reference to the big past K-V and double the necessary memory + # in the next loop + del next_model_inputs + + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + + # .float() is needed to retain precision for later logits manipulations + logits = outputs.logits[:, -1, :].float() + context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) + + # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the + # model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't + # introduce (noticeable) slowdowns on single-device runs. + selected_idx = _ranking_fast( + context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k + ) + cosine_matrix_mask = torch.cat( + [cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1 + ) + selected_idx = selected_idx.to("cpu") + + # This will be used instead of the previous inneficient torch.stack(torch.split()) + augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)]) + + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing + # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores + # (model confidence minus degeneration penalty); (6) decoder hidden_states + next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx] + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k)) + next_hidden = next_hidden[range(batch_size), selected_idx, :] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + + next_decoder_hidden_states = () + for layer in full_hidden_states: + layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :] + next_decoder_hidden_states += (layer,) + + # generate past_key_values cache of only the selected token + if sequential: + next_model_input = self.prepare_inputs_for_generation( + top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs + ) + + selected_outputs = self( + **next_model_input, + return_dict=True, + output_hidden_states=False, + output_attentions=False, + ) + next_past_key_values = selected_outputs["past_key_values"] + + else: + _, next_past_key_values = self._extract_past_from_model_output(outputs) + # Do it in-place layer per layer to save memory + if isinstance(next_past_key_values, DynamicCache) or ( + isinstance(next_past_key_values, EncoderDecoderCache) + and isinstance(next_past_key_values.self_attention_cache, DynamicCache) + ): + next_past_key_values.batch_select_indices(augmented_idx) + else: + new_key_values = [] + for layer in next_past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item[augmented_idx, ...]) + new_key_values.append(tuple(items)) + + next_past_key_values = tuple(new_key_values) + + logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] + logit_for_next_step = logit_for_next_step.to(input_ids.device) + + # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration + if self.config.is_encoder_decoder: + next_step_cross_attentions = () + next_step_decoder_attentions = () + if output_attentions: + for layer in outputs.cross_attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] + next_step_cross_attentions += (layer,) + for layer in outputs.decoder_attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] + next_step_decoder_attentions += (layer,) + outputs = Seq2SeqLMOutput( + past_key_values=next_past_key_values, + decoder_hidden_states=next_decoder_hidden_states, + decoder_attentions=next_step_decoder_attentions or None, + cross_attentions=next_step_cross_attentions or None, + ) + else: + next_step_attentions = () + if output_attentions: + for layer in outputs.attentions: + layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...] + next_step_attentions += (layer,) + outputs = CausalLMOutputWithPast( + past_key_values=next_past_key_values, + hidden_states=next_decoder_hidden_states, + attentions=next_step_attentions or None, + ) + # contrastive_search main logic end + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + continue + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + + # stop when each sentence is finished + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + # Contrastive search works by forward looking at the next token, so we need to exclude it from + # `past_key_values` to be consistent with the other decoding methods + if model_kwargs.get("past_key_values") is not None: + if isinstance(model_kwargs["past_key_values"], DynamicCache) or ( + isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) + and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache) + ): + model_kwargs["past_key_values"].crop(-1) + else: + past_key_values = [] + for layer in model_kwargs["past_key_values"]: + layer_past_key_values = [] + for item in layer: + layer_past_key_values.append(item[..., :-1, :]) + past_key_values.append(tuple(layer_past_key_values)) + model_kwargs["past_key_values"] = tuple(past_key_values) + + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + continue + + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits.clone()[:, -1, :].float() + next_token_logits = next_token_logits.to(input_ids.device) + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + def _temporary_reorder_cache(self, past_key_values, beam_idx): + """ + Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. + + TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need + for this function, with `Cache.reorder_cache` being the sole remaining code path + """ + model_class = self.__class__.__name__.lower() + # Exception 1: code path for models using the legacy cache format + if isinstance(past_key_values, (tuple, list)): + past_key_values = self._reorder_cache(past_key_values, beam_idx) + # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their + # cache format is standardized, to avoid adding complexity to the codebase. + elif "gptbigcode" in model_class: + if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)): + raise ValueError( + f"Using an unsupported cache format with {model_class}. Currently, it only supports the " + "legacy tuple format or `DynamicCache`" + ) + past_key_values = self._reorder_cache(past_key_values, beam_idx) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # Standard code path: use the `Cache.reorder_cache` + else: + past_key_values.reorder_cache(beam_idx) + return past_key_values + + def _beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + **model_kwargs, + ) -> Union[GenerateBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **beam search decoding** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`: + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + sequential = generation_config.low_memory + do_sample = generation_config.do_sample + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens + # of the first beam are considered to avoid sampling the exact same tokens across all beams. + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + + # print("model_kwargs: ", model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + # if sequential is True, split the input to batches of batch_size and run sequentially + if sequential: + if any( + model_name in self.__class__.__name__.lower() + for model_name in [ + "fsmt", + "reformer", + "ctrl", + "gpt_bigcode", + "transo_xl", + "xlnet", + "cpm", + "jamba", + ] + ): + raise RuntimeError( + f"Currently generation for {self.__class__.__name__} is not supported " + f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature." + ) + + inputs_per_sub_batches = _split_model_inputs( + model_inputs, + split_size=batch_size, + full_batch_size=batch_beam_size, + config=self.config.get_text_config(), + ) + outputs_per_sub_batch = [ + self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches + ] + + outputs = stack_model_outputs(outputs_per_sub_batch, self.config.get_text_config()) + + else: # Unchanged original behavior + outputs = self(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue + + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[:, -1, :].clone().float() + next_token_logits = next_token_logits.to(input_ids.device) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores_processed,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1 + # non eos token per beam. + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 + n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams + if do_sample: + # import time + # start = time.time() + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + # print("*"*20, probs.shape, n_tokens_to_keep, next_token_scores.shape, next_tokens.shape) + # print("*"*20, time.time() - start) + else: + next_token_scores, next_tokens = torch.topk( + next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + + if model_kwargs.get("past_key_values", None) is not None: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) + + if return_dict_in_generate and output_scores: + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + + if self.config.is_encoder_decoder: + return GenerateBeamEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequence_outputs["sequences"] + + def _group_beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + **model_kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head using **diverse beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + model_kwargs: + Additional model specific kwargs that will be forwarded to the `forward` function of the model. If + model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_size = len(beam_scorer._beam_hyps) // num_beam_groups + device = input_ids.device + + batch_beam_size, cur_len = input_ids.shape + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + if return_dict_in_generate and output_scores: + beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] + else: + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in + # the same group don't produce same tokens every time. + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + outputs = self(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue + + if output_scores: + processed_score = torch.zeros_like(outputs.logits[:, -1, :]) + if output_logits: + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + raw_logit_score = outputs.logits[:, -1, :].clone() + raw_logit_score = raw_logit_score.to(input_ids.device) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of current group only + # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[batch_group_indices, -1, :].float() + next_token_logits = next_token_logits.to(input_ids.device) + + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * group_size, vocab_size) + vocab_size = next_token_scores.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + if output_scores: + processed_score[batch_group_indices] = next_token_scores_processed + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 + next_token_scores, next_tokens = torch.topk( + next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + group_index=beam_group_idx, + decoder_prompt_len=decoder_prompt_len, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + if return_dict_in_generate and output_scores: + beam_indices[beam_group_idx] = tuple( + beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) + ) + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + + group_start_idx + + (beam_idx % group_size) + ) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (processed_score,) + if output_logits: + raw_logits += (raw_logit_score,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + + if model_kwargs.get("past_key_values", None) is not None: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], reordering_indices + ) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + + if self.config.is_encoder_decoder: + return GenerateBeamEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequence_outputs["sequences"] + + def _constrained_beam_search( + self, + input_ids: torch.LongTensor, + constrained_beam_scorer: ConstrainedBeamSearchScorer, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + **model_kwargs, + ) -> Union[GenerateBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **constrained beam search + decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + constrained_beam_scorer (`ConstrainedBeamSearchScorer`): + A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation, while satisfying a list of positive constraints. For more information, the + documentation of [`ConstrainedBeamSearchScorer`] should be read. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + + batch_size = len(constrained_beam_scorer._beam_hyps) + num_beams = constrained_beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens + # of the first beam are considered to avoid sampling the exact same tokens across all beams. + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False + + decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + outputs = self(**model_inputs, return_dict=True) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue + + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + # .float() is needed to retain precision for later logits manipulations + next_token_logits = outputs.logits[:, -1, :].clone().float() + next_token_logits = next_token_logits.to(input_ids.device) + next_token_scores = nn.functional.log_softmax( + next_token_logits, dim=-1 + ) # (batch_size * num_beams, vocab_size) + + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( + next_token_scores_processed + ) + + scores_for_all_vocab = next_token_scores.clone() + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + + # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam. + n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0 + next_token_scores, next_tokens = torch.topk( + next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = (next_tokens / vocab_size).long() + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = constrained_beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + scores_for_all_vocab, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + + if model_kwargs.get("past_key_values", None) is not None: + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) + + if return_dict_in_generate and output_scores: + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) + + # increase cur_len + cur_len = cur_len + 1 + + if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)): + this_peer_finished = True + + sequence_outputs = constrained_beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=beam_indices, + decoder_prompt_len=decoder_prompt_len, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + if self.config.is_encoder_decoder: + return GenerateBeamEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateBeamDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + logits=raw_logits, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return sequence_outputs["sequences"] + + def _assisted_decoding( + self, + input_ids: torch.LongTensor, + candidate_generator: CandidateGenerator, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding** or + **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a + candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text + models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + candidate_generator (`CandidateGenerator`): + A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For + more information, the documentation of [`CandidateGenerator`] should be read. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + do_sample = generation_config.do_sample + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size = input_ids.shape[0] + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + this_peer_finished = False + is_first_iteration = True # to preserve the same API in the output as other generation methods + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + cur_len = input_ids.shape[-1] + + # 1. Fetch candidate sequences from a `CandidateGenerator` + candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) + + if candidate_logits is not None: + candidate_logits = candidate_logits.to(self.device) + + candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] + is_done_candidate = stopping_criteria(candidate_input_ids, None) + + # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain + # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, + # we use this forward pass to also pick the subsequent logits in the original model. + + # 2.1. Prepare the model inputs + candidate_kwargs = copy.copy(model_kwargs) + candidate_kwargs = _prepare_attention_mask( + candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + ) + candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + if "cache_position" in candidate_kwargs: + candidate_kwargs["cache_position"] = torch.cat( + ( + candidate_kwargs["cache_position"], + torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long), + ), + dim=0, + ) + + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) + if "num_logits_to_keep" in model_inputs: + model_inputs["num_logits_to_keep"] = candidate_length + 1 + + # 2.2. Run a forward pass on the candidate sequence + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + outputs = self(**model_inputs) + + # 2.3. Process the new logits + # .float() is needed to retain precision for later logits manipulations + new_logits = outputs.logits[:, -candidate_length - 1 :].float() # excludes the input prompt if present + new_logits = new_logits.to(input_ids.device) + next_token_logits = new_logits.clone() + if len(logits_processor) > 0: + for i in range(candidate_length + 1): + new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + + # 3. Select the accepted tokens. There are two possible cases: + # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) + # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf). + if do_sample and candidate_logits is not None: + valid_tokens, n_matches = _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + is_done_candidate, + ) + + # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the + # original model logits with the candidate tokens. We can keep the candidate tokens until the first + # mismatch, or until the max length is reached. + else: + if do_sample: + probs = new_logits.softmax(dim=-1) + selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + else: + selected_tokens = new_logits.argmax(dim=-1) + + candidate_new_tokens = candidate_input_ids[:, cur_len:] + n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() + + # Ensure we don't generate beyond max_len or an EOS token + if is_done_candidate and n_matches == candidate_length: + n_matches -= 1 + valid_tokens = selected_tokens[:, : n_matches + 1] + + # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated + # by the model after the last candidate match is also valid, as it is generated from a correct sequence. + # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there + # is no match. + + # 4.1. Get the valid continuation, after the matching tokens + input_ids = torch.cat((input_ids, valid_tokens), dim=-1) + if streamer is not None: + streamer.put(valid_tokens.cpu()) + new_cur_len = input_ids.shape[-1] + + # 4.2. Discard past key values relative to unused assistant tokens + new_cache_size = new_cur_len - 1 + outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size) + + # 5. Update the candidate generation strategy if needed + candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + num_new_tokens=n_matches + 1, + ) + if synced_gpus and this_peer_finished: + continue + + # Store scores, attentions and hidden_states when required + # Assistant: modified to append one tuple element per token, as in the other generation methods. + if return_dict_in_generate: + newly_added_length = n_matches + 1 + if output_scores: + scores += tuple(new_logits[:, i, :] for i in range(newly_added_length)) + if output_logits: + raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length)) + + newly_added_length = new_cur_len if is_first_iteration else newly_added_length + if output_attentions: + if self.config.is_encoder_decoder: + cross_attentions = _split_model_outputs( + cross_attentions, outputs.cross_attentions, cur_len, newly_added_length + ) + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.decoder_attentions, + cur_len, + newly_added_length, + is_decoder_attention=True, + ) + # some (V)LLMs have hard requirement on SDPA and thus never return attn + elif outputs.attentions[0] is not None: + decoder_attentions = _split_model_outputs( + decoder_attentions, + outputs.attentions, + cur_len, + newly_added_length, + is_decoder_attention=True, + ) + if output_hidden_states: + if self.config.is_encoder_decoder: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length + ) + else: + decoder_hidden_states = _split_model_outputs( + decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length + ) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + is_first_iteration = False + + if streamer is not None: + streamer.end() + + if ( + hasattr(candidate_generator, "assistant_model") + and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic" + ): + candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( + candidate_generator.num_assistant_tokens + ) + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids + + +def _speculative_sampling( + candidate_input_ids, + candidate_logits, + candidate_length, + new_logits, + is_done_candidate, +): + """ + Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns + the selected tokens, as well as the number of candidate matches. + + NOTE: Unless otherwise stated, the variable names match those in the paper. + """ + new_candidate_input_ids = candidate_input_ids[:, -candidate_length:] + # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens + # selected by the assistant, respectively. + q = candidate_logits.softmax(dim=-1) + q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) + p = new_logits.softmax(dim=-1) + p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) + probability_ratio = p_i / q_i + + # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller + # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio + # (= keep with p = probability_ratio). Keep all the tokens until the first rejection + r_i = torch.rand_like(probability_ratio) + is_accepted = r_i <= probability_ratio + n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 + + # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) + if is_done_candidate and n_matches == candidate_length: + # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model + # due to acceptance on EOS we fix `n_matches` + n_matches -= 1 + valid_tokens = new_candidate_input_ids[:, : n_matches + 1] + else: + # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. + gamma = candidate_logits.shape[1] + p_n_plus_1 = p[:, n_matches, :] + if n_matches < gamma: + q_n_plus_1 = q[:, n_matches, :] + p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0) + p_prime.div_(p_prime.sum()) + else: + p_prime = p_n_plus_1 + t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] + + # The selected tokens include the matches (if any) plus the next sampled tokens + if n_matches > 0: + valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) + else: + valid_tokens = t + + return valid_tokens, n_matches + + +def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): + """ + Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple + where each member corresponds to a single generated token. + """ + # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the + # prompt. + if len(outputs) == 0: + new_tuple = () + for layer in new_outputs: + last_dim_size = cur_len if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., :cur_len, :last_dim_size],) + outputs += (new_tuple,) + # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly + cur_len += 1 + added_len -= cur_len + + for i in range(added_len): + new_tuple = () + for layer in new_outputs: + last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1] + new_tuple += (layer[..., i : i + 1, :last_dim_size],) + outputs += (new_tuple,) + return outputs + + +def _ranking_fast( + context_hidden: torch.FloatTensor, + next_hidden: torch.FloatTensor, + next_top_k_probs: torch.FloatTensor, + cosine_matrix_mask: torch.LongTensor, + alpha: float, + beam_width: int, +) -> torch.FloatTensor: + """ + Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described + in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each + row in the batch. + """ + norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) + norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S] + + # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions) + # Using a large negative value for masked positions + cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype) + cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min + cosine_matrix = cosine_matrix + cosine_matrix_mask + + degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K] + next_top_k_probs = next_top_k_probs.view(-1) # [B*K] + contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty + contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] + _, selected_idx = contrastive_score.max(dim=-1) # [B] + return selected_idx + + +def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = None): + """ + Takes care of three cases: + 1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim + 2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and + return a list of tuples + 3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and + return a list of tuples of tuples + (see documentation of ModelOutput) + """ + if data is None: + return [None] * (full_batch_size // split_size) + if isinstance(data, torch.Tensor): + return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] + # New cache format + elif isinstance(data, DynamicCache) or ( + isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) + ): + return data.batch_split(full_batch_size, split_size, num_hidden_layers) + elif isinstance(data, tuple): + # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) + if isinstance(data[0], tuple): + return [ + tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data) + for i in range(0, full_batch_size, split_size) + ] + + else: + return [ + tuple(sub_tensor[i : i + split_size] for sub_tensor in data) + for i in range(0, full_batch_size, split_size) + ] + else: + raise TypeError(f"Unexpected attribute type: {type(data)}") + + +def _split_model_inputs( + model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int, config: PretrainedConfig +) -> List[Union[ModelOutput, Dict]]: + """ + Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split + size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from + previous forward pass. + """ + # Edge case: if model_input is None, return a list of Nones + # this happens with Whisper where encoder_outputs is None + if model_input is None: + return [model_input] * (full_batch_size // split_size) + # Infer the class from the object + model_output_cls = type(model_input) + if (full_batch_size % split_size) != 0: + raise ValueError("`full_batch_size` must be divisible by `split_size`") + + if split_size > full_batch_size: + raise ValueError("`split_size` must be smaller or equal to `full_batch_size`") + + # Helper function to split tensors or tuples of tensors + + # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them + keys = ( + model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys() + ) + # We only keep keys that are in the model_input + keys = [k for k in keys if k in model_input] + # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a + # ModelOutput object. + # bool should not be split but replicated for each split + bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] + keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] + + num_hidden_layers = config.get_text_config().num_hidden_layers + + # we split the tensors and tuples of tensors + data_split_list = [ + {k: _split(model_input[k], full_batch_size, num_hidden_layers, split_size)[i] for k in non_bool_keys} + for i in range(full_batch_size // split_size) + ] + # bool values are the same and replicated for each split + bool_data = {k: model_input[k] for k in bool_keys} + # encoder_outputs is a ModelOutput object and should be split by its own + if "encoder_outputs" in model_input: + encoder_outputs_split = _split_model_inputs( + model_input["encoder_outputs"], split_size, full_batch_size, config.get_text_config() + ) + data_split_list = [ + {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list) + ] + # num_logits_to_keep should be replicated for each split, similar to bool values + if "num_logits_to_keep" in model_input: + data_split_list = [ + {**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list + ] + + # Convert each dictionary in the list to an object of the inferred class + split_model_inputs: List[Union[ModelOutput, Dict]] = [ + model_output_cls(**data_split, **bool_data) for data_split in data_split_list + ] + + return split_model_inputs + + +def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConfig) -> ModelOutput: + """ + Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the + specific ModelOutput subclass from the list provided. + """ + if not model_outputs: + raise ValueError("Input list is empty.") + + # Infer the class from the first object in the list + model_output_cls = type(model_outputs[0]) + num_hidden_layers = config.get_text_config().num_hidden_layers + + # Ensure all objects are of the same type + if not all(isinstance(obj, model_output_cls) for obj in model_outputs): + raise ValueError("All elements in the list should be of the same type.") + + # Helper function to concat tensors or tuples of tensors + def _concat(data): + """ + Reverse of `_split` function above. + """ + if any(data is None for data in data): + return None + if isinstance(data[0], torch.Tensor): + return torch.cat(data, dim=0) + # New cache format + elif isinstance(data[0], DynamicCache): + return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) + elif isinstance(data[0], EncoderDecoderCache): + return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) + elif isinstance(data[0], tuple): + # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) + if isinstance(data[0][0], tuple): + return tuple( + tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0]))) + for i in range(len(data[0])) + ) + else: + return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0]))) + elif isinstance(data[0], (int, float)): + # If the elements are integers or floats, return a tensor + return torch.tensor(data) + else: + raise TypeError(f"Unexpected attribute type: {type(data[0])}") + + # Use a dictionary comprehension to gather attributes from all objects and concatenate them + concatenated_data = { + k: _concat([getattr(model_output, k) for model_output in model_outputs]) + for k in model_output_cls.__dataclass_fields__.keys() + } + + # Return a new object of the inferred class with the concatenated attributes + return model_output_cls(**concatenated_data) + + +def _relative_top_filter( + scores: torch.FloatTensor, + baseline_scores: torch.FloatTensor, + relative_top: float = 0.1, + filter_value: float = -float("Inf"), + base_filter_value=-1e-3, + min_tokens_to_keep: int = 1, +) -> torch.FloatTensor: + """ + Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235 + Apply filtering to only keep tokens with a probability above a certain threshold. The threshold is defined as `relative_top` * max probability in the distribution. + """ + scores_normalized = scores.log_softmax(dim=-1) + baseline_scores_normalized = baseline_scores.log_softmax(dim=-1) + sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True) + min_thresh = sorted_logits[..., min_tokens_to_keep - 1] + probs_max = torch.max(scores_normalized, dim=-1).values + probs_thresh = probs_max + np.log(relative_top) + probs_thresh = torch.min(min_thresh, probs_thresh) + probs_thresh = probs_thresh.unsqueeze(-1) + baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value + scores_normalized[scores_normalized < probs_thresh] = filter_value + return scores_normalized, baseline_scores_normalized + + +def _dola_select_contrast( + candidate_premature_layers: List[int], + candidate_premature_logits: Dict[int, torch.FloatTensor], + final_logits: torch.FloatTensor, +) -> torch.FloatTensor: + if len(candidate_premature_layers) == 1: + base_logits = candidate_premature_logits[candidate_premature_layers[0]] + final_logits, base_logits = _relative_top_filter(final_logits, base_logits) + logits = final_logits - base_logits + return logits + + # 1. Stacking all premature_layers into a new dimension + stacked_premature_layers = torch.stack([candidate_premature_logits[i] for i in candidate_premature_layers], dim=0) + + # 2. Calculate the softmax values for mature_layer and all premature_layers + # shape: (batch_size, vocab_size) + softmax_mature_layer = F.softmax(final_logits, dim=-1) + # shape: (num_premature_layers, batch_size, vocab_size) + softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1) + + # 3. Calculate the average distribution + # shape: (num_premature_layers, batch_size, vocab_size) + avg_dist = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers) + + # 4. Calculate log-softmax for the KL divergence + # shape: (batch_size, vocab_size) + log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1) + # shape: (num_premature_layers, batch_size, vocab_size) + log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1) + + # 5. Calculate the KL divergences and then the JS divergences + # shape: (num_premature_layers, batch_size) + kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], avg_dist, reduction="none").mean(-1) + # shape: (num_premature_layers, batch_size) + kl2 = F.kl_div(log_softmax_premature_layers, avg_dist, reduction="none").mean(-1) + js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size) + + # 6. Reduce the batchmean + js_divs = js_divs.mean(-1) # shape: (num_premature_layers,) + premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())] + + base_logits = candidate_premature_logits[premature_layer] + final_logits, base_logits = _relative_top_filter(final_logits, base_logits) + logits = final_logits - base_logits + return logits diff --git a/indextts/gpt/transformers_gpt2.py b/indextts/gpt/transformers_gpt2.py new file mode 100755 index 0000000000000000000000000000000000000000..ab7fa96d8218667fa13ada4bdb83b1d65271c0bf --- /dev/null +++ b/indextts/gpt/transformers_gpt2.py @@ -0,0 +1,1878 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +import transformers + +from indextts.gpt.transformers_generation_utils import GenerationMixin +from indextts.gpt.transformers_modeling_utils import PreTrainedModel +from transformers.modeling_utils import SequenceSummary + +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +# from transformers.modeling_utils import PreTrainedModel, SequenceSummary + +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + get_torch_version, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with torch.amp.autocast(query.device.type, enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attn_dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + +class GPT2SdpaAttention(GPT2Attention): + """ + GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass + to adapt to the SDPA API. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__ + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + bsz, q_len, _ = hidden_states.size() + + # Initial attention projections + is_cross_attention = encoder_hidden_states is not None + if is_cross_attention: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + # Optional kv caching + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA + if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=self.attn_dropout.p if self.training else 0.0, + is_causal=is_causal, + ) + + # Reshape outputs + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.embed_dim) + + # Final projection + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + return attn_output, present, None + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, *optional*): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + _supports_param_buffer_assignment = False + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + # Attention mask. + _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None + attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif _use_sdpa: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(batch_size, input_shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) + else: + if attention_mask is not None: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if _use_sdpa: + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + elif not self._attn_implementation == "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/indextts/gpt/transformers_modeling_utils.py b/indextts/gpt/transformers_modeling_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..39c28d080c3646c43fcdedcb5dd9a9ec361a045c --- /dev/null +++ b/indextts/gpt/transformers_modeling_utils.py @@ -0,0 +1,5525 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +import copy +import functools +import gc +import importlib.metadata +import inspect +import itertools +import json +import os +import re +import shutil +import tempfile +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial, wraps +from threading import Thread +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from zipfile import is_zipfile + +import torch +from huggingface_hub import split_torch_state_dict_into_shards +from packaging import version +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss, Identity +from torch.utils.checkpoint import checkpoint +from transformers.activations import get_activation +from transformers.configuration_utils import PretrainedConfig +from transformers.dynamic_module_utils import custom_object_save +from transformers.generation import GenerationConfig +import transformers +from indextts.gpt.transformers_generation_utils import GenerationMixin +from transformers.generation import GenerationConfig + + +from transformers.integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled +from transformers.loss.loss_utils import LOSS_MAPPING +from transformers.pytorch_utils import ( # noqa: F401 + Conv1D, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + id_tensor_storage, + is_torch_greater_or_equal_than_1_13, + prune_conv1d_layer, + prune_layer, + prune_linear_layer, +) +from transformers.quantizers import AutoHfQuantizer, HfQuantizer +from transformers.quantizers.quantizers_utils import get_module_from_name +from transformers.safetensors_conversion import auto_conversion +from transformers.utils import ( + ACCELERATE_MIN_VERSION, + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + CONFIG_NAME, + DUMMY_INPUTS, + FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ContextManagers, + ModelOutput, + PushToHubMixin, + cached_file, + copy_func, + download_url, + extract_commit_hash, + has_file, + is_accelerate_available, + is_bitsandbytes_available, + is_flash_attn_2_available, + is_offline_mode, + is_optimum_available, + is_peft_available, + is_remote_url, + is_safetensors_available, + is_torch_sdpa_available, + is_torch_xla_available, + logging, + replace_return_docstrings, + strtobool, +) +from transformers.utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files +from transformers.utils.import_utils import ( + ENV_VARS_TRUE_VALUES, + is_sagemaker_mp_enabled, + is_torch_fx_proxy, + is_torchdynamo_compiling, +) +from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod + + +XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper() +XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper() + + +if is_accelerate_available(): + from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights + from accelerate.hooks import add_hook_to_module + from accelerate.utils import ( + check_tied_parameters_on_same_device, + extract_model_from_parallel, + find_tied_parameters, + get_balanced_memory, + get_max_memory, + load_offloaded_weights, + offload_weight, + save_offload_index, + set_module_tensor_to_device, + ) + + accelerate_version = version.parse(importlib.metadata.version("accelerate")) + if accelerate_version >= version.parse("0.31"): + from accelerate.utils.modeling import get_state_dict_from_offload + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import load_file as safe_load_file + from safetensors.torch import save_file as safe_save_file + +logger = logging.get_logger(__name__) + + +_init_weights = True + + +def is_fsdp_enabled(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1 + and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1 + ) + + +def is_local_dist_rank_0(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and int(os.environ.get("LOCAL_RANK", -1)) == 0 + ) + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") +else: + IS_SAGEMAKER_MP_POST_1_10 = False + +if is_peft_available(): + from transformers.utils import find_adapter_config_file + +TORCH_INIT_FUNCTIONS = { + "uniform_": nn.init.uniform_, + "normal_": nn.init.normal_, + "trunc_normal_": nn.init.trunc_normal_, + "constant_": nn.init.constant_, + "xavier_uniform_": nn.init.xavier_uniform_, + "xavier_normal_": nn.init.xavier_normal_, + "kaiming_uniform_": nn.init.kaiming_uniform_, + "kaiming_normal_": nn.init.kaiming_normal_, + "uniform": nn.init.uniform, + "normal": nn.init.normal, + "xavier_uniform": nn.init.xavier_uniform, + "xavier_normal": nn.init.xavier_normal, + "kaiming_uniform": nn.init.kaiming_uniform, + "kaiming_normal": nn.init.kaiming_normal, +} + + +@contextmanager +def no_init_weights(_enable=True): + """ + Context manager to globally disable weight initialization to speed up loading large models. + + TODO(Patrick): Delete safety argument `_enable=True` at next major version. . + """ + global _init_weights + old_init_weights = _init_weights + + if _enable: + _init_weights = False + + def _skip_init(*args, **kwargs): + pass + + # # Save the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, _skip_init) + try: + yield + finally: + _init_weights = old_init_weights + if _enable: + # # Restore the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, init_func) + + +def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]): + try: + return next(parameter.parameters()).device + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): + """ + Returns the first parameter dtype (can be non-floating) or asserts if none were found. + """ + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch > 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. + """ + last_dtype = None + for t in parameter.parameters(): + last_dtype = t.dtype + if t.is_floating_point(): + # Adding fix for https://github.com/pytorch/xla/issues/4152 + # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1 + # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf + # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo + if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available(): + return torch.bfloat16 + if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available(): + if t.dtype == torch.float: + return torch.bfloat16 + if t.dtype == torch.double: + return torch.float32 + return t.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for tuple in gen: + last_tuple = tuple + if tuple[1].is_floating_point(): + return tuple[1].dtype + + if last_tuple is not None: + # fallback to the last dtype + return last_tuple[1].dtype + + # fallback to buffer dtype + for t in parameter.buffers(): + last_dtype = t.dtype + if t.is_floating_point(): + return t.dtype + return last_dtype + + +def get_state_dict_float_dtype(state_dict): + """ + Returns the first found floating dtype in `state_dict` or asserts if none were found. + """ + for t in state_dict.values(): + if t.is_floating_point(): + return t.dtype + + raise ValueError("couldn't find any floating point dtypes in state_dict") + + +def get_state_dict_dtype(state_dict): + """ + Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype. + """ + for t in state_dict.values(): + if t.is_floating_point(): + return t.dtype + + # if no floating dtype was found return whatever the first dtype is + else: + return next(state_dict.values()).dtype + + +def dtype_byte_size(dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + + Example: + + ```py + >>> dtype_byte_size(torch.float32) + 4 + ``` + """ + if dtype == torch.bool: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): + """ + Checks if `model_to_load` supports param buffer assignment (such + as when loading in empty weights) by first checking + if the model explicitly disables it, then by ensuring that the state dict keys + are a subset of the model's parameters. + + Note: We fully disable this if we are using `deepspeed` + """ + if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: + return False + + if is_deepspeed_zero3_enabled(): + return False + + # Some models explicitly do not support param buffer assignment + if not getattr(model_to_load, "_supports_param_buffer_assignment", True): + logger.debug( + f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" + ) + return False + + # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype + first_key = list(model_to_load.state_dict().keys())[0] + if start_prefix + first_key in state_dict: + return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype + + # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`) + return False + + +def shard_checkpoint( + state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME +): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + + The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no + optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the + limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], + [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`): + The name of the model save file. + """ + logger.warning( + "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using " + "split_torch_state_dict_into_shards from huggingface_hub library" + ) + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [{}] + last_block_size = 0 + total_size = 0 + storage_id_to_block = {} + + for key, weight in state_dict.items(): + # when bnb serialization is used the weights in the state dict can be strings + # check: https://github.com/huggingface/transformers/pull/24416 for more details + if isinstance(weight, str): + continue + else: + storage_id = id_tensor_storage(weight) + + # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block` + if storage_id in storage_id_to_block and weight.device != torch.device("meta"): + block_id = storage_id_to_block[storage_id] + sharded_state_dicts[block_id][key] = weight + continue + + weight_size = weight.numel() * dtype_byte_size(weight.dtype) + # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one + # weight in the current shard. + if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0: + sharded_state_dicts.append({}) + last_block_size = 0 + + sharded_state_dicts[-1][key] = weight + last_block_size += weight_size + total_size += weight_size + storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1 + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) + shards[shard_file] = shard + for key in shard.keys(): + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): + """ + This is the same as + [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) + but for a sharded checkpoint. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + model (`torch.nn.Module`): The model in which to load the checkpoint. + folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. + strict (`bool`, *optional`, defaults to `True`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + prefer_safe (`bool`, *optional*, defaults to `False`) + If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the + safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible. + + Returns: + `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields + - `missing_keys` is a list of str containing the missing keys + - `unexpected_keys` is a list of str containing the unexpected keys + """ + # Load the index + index_file = os.path.join(folder, WEIGHTS_INDEX_NAME) + safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + + if not index_present and not (safe_index_present and is_safetensors_available()): + filenames = ( + (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,) + ) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") + + load_safe = False + if safe_index_present: + if prefer_safe: + if is_safetensors_available(): + load_safe = True # load safe due to preference + else: + logger.warning( + f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!" + ) + elif not index_present: + load_safe = True # load safe since we have no other choice + + load_index = safe_index_file if load_safe else index_file + + with open(load_index, "r", encoding="utf-8") as f: + index = json.load(f) + + shard_files = list(set(index["weight_map"].values())) + + # If strict=True, error before loading any of the state dicts. + loaded_keys = index["weight_map"].keys() + model_keys = model.state_dict().keys() + missing_keys = [key for key in model_keys if key not in loaded_keys] + unexpected_keys = [key for key in loaded_keys if key not in model_keys] + if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if len(missing_keys) > 0: + str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if len(unexpected_keys) > 0: + str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) + error_message += f"\nMissing key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg) + + for shard_file in shard_files: + state_dict = loader(os.path.join(folder, shard_file)) + model.load_state_dict(state_dict, strict=False) + + # Make sure memory is freed before we load the next state dict. + del state_dict + gc.collect() + + # Return the same thing as PyTorch load_state_dict function. + return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) + + +def load_state_dict( + checkpoint_file: Union[str, os.PathLike], + is_quantized: bool = False, + map_location: Optional[Union[str, torch.device]] = None, + weights_only: bool = True, +): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): + # Check format of the archive + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + return safe_load_file(checkpoint_file) + try: + if map_location is None: + if ( + ( + is_deepspeed_zero3_enabled() + and torch.distributed.is_initialized() + and torch.distributed.get_rank() > 0 + ) + or (is_fsdp_enabled() and not is_local_dist_rank_0()) + ) and not is_quantized: + map_location = "meta" + else: + map_location = "cpu" + extra_args = {} + # mmap can only be used with files serialized with zipfile-based format. + if ( + isinstance(checkpoint_file, str) + and map_location != "meta" + and version.parse(torch.__version__) >= version.parse("2.1.0") + and is_zipfile(checkpoint_file) + ): + extra_args = {"mmap": True} + weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {} + return torch.load( + checkpoint_file, + map_location=map_location, + **weights_only_kwarg, + **extra_args, + ) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read(7) == "version": + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def set_initialized_submodules(model, state_dict_keys): + """ + Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state + dict. + """ + not_initialized_submodules = {} + for module_name, module in model.named_modules(): + loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")} + # When checking if the root module is loaded all state_dict_keys must be used. + if module_name == "": + loaded_keys = set(state_dict_keys) + if loaded_keys.issuperset(module.state_dict()): + module._is_hf_initialized = True + else: + not_initialized_submodules[module_name] = module + return not_initialized_submodules + + +def _end_ptr(tensor: torch.Tensor) -> int: + # extract the end of the pointer if the tensor is a slice of a bigger tensor + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() + else: + stop = tensor.data_ptr() + return stop + + +def _get_tied_weight_keys(module: nn.Module, prefix=""): + tied_weight_keys = [] + if getattr(module, "_tied_weights_keys", None) is not None: + names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys] + tied_weight_keys.extend(names) + if getattr(module, "_dynamic_tied_weights_keys", None) is not None: + names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys] + tied_weight_keys.extend(names) + for name, submodule in module.named_children(): + local_prefix = f"{prefix}.{name}" if prefix else name + tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix)) + return tied_weight_keys + + +def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]: + filtered_tensors = [] + for shared in tensors: + if len(shared) < 2: + filtered_tensors.append(shared) + continue + + areas = [] + for name in shared: + tensor = state_dict[name] + areas.append((tensor.data_ptr(), _end_ptr(tensor), name)) + areas.sort() + + _, last_stop, last_name = areas[0] + filtered_tensors.append({last_name}) + for start, stop, name in areas[1:]: + if start >= last_stop: + filtered_tensors.append({name}) + else: + filtered_tensors[-1].add(name) + last_stop = stop + disjoint_tensors = [] + shared_tensors = [] + for tensors in filtered_tensors: + if len(tensors) == 1: + disjoint_tensors.append(tensors.pop()) + else: + shared_tensors.append(tensors) + return shared_tensors, disjoint_tensors + + +def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]: + shared_tensors = [] + identical = [] + for shared in tensors: + if len(shared) < 2: + continue + + areas = collections.defaultdict(set) + for name in shared: + tensor = state_dict[name] + area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor)) + areas[area].add(name) + if len(areas) == 1: + identical.append(shared) + else: + shared_tensors.append(shared) + return shared_tensors, identical + + +def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False): + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + renamed_keys = {} + renamed_gamma = {} + renamed_beta = {} + warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` " + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + # We add only the first key as an example + new_key = key.replace("gamma", "weight") + renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma + if "beta" in key: + # We add only the first key as an example + new_key = key.replace("beta", "bias") + renamed_beta[key] = new_key if not renamed_beta else renamed_beta + if new_key: + old_keys.append(key) + new_keys.append(new_key) + renamed_keys = {**renamed_gamma, **renamed_beta} + if renamed_keys: + warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n" + for old_key, new_key in renamed_keys.items(): + warning_msg += f"* `{old_key}` -> `{new_key}`\n" + warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." + logger.info_once(warning_msg) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + local_metadata["assign_to_params_buffers"] = assign_to_params_buffers + + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + if is_deepspeed_zero3_enabled(): + import deepspeed + + # In sharded models, each shard has only part of the full state_dict, so only gather + # parameters that are in the current state_dict. + named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False)) + params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters] + if len(params_to_gather) > 0: + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".", assign_to_params_buffers) + + load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers) + # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so + # it's safe to delete it. + del state_dict + + return error_msgs + + +def find_submodule_and_param_name(model, long_key, start_prefix): + """ + A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed + from the start of the key + """ + + if len(start_prefix) > 0 and long_key.startswith(start_prefix): + long_key = ".".join(long_key.split(".")[1:]) + + split_key = long_key.split(".") + submodule = model + while len(split_key) > 1: + if hasattr(submodule, split_key[0]): + submodule = getattr(submodule, split_key[0]) + del split_key[0] + else: + submodule = None + break + if submodule == model: + submodule = None + return submodule, split_key[0] + + +def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix): + """ + Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params. + + `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in + `bert.pooler.dense.weight` + + """ + + # dematerialize param storage for keys that are going to be replaced by state_dict, by + # putting those on the meta device + for k in loaded_state_dict_keys: + submodule, param_name = find_submodule_and_param_name(model, k, start_prefix) + if submodule is not None: + # selectively switch to the meta device only those params/buffers that will + # be next replaced from state_dict. This a complex way to do p.to_("meta") + # since we have no in-place to_ for tensors. + new_val = getattr(submodule, param_name) + if isinstance(new_val, torch.nn.Parameter): + # isinstance returns False for Params on meta device, so switch after the check + new_val = torch.nn.Parameter(new_val.to("meta")) + else: + new_val = new_val.to("meta") + setattr(submodule, param_name, new_val) + + +def _load_state_dict_into_meta_model( + model, + state_dict, + start_prefix, + expected_keys, + device_map=None, + offload_folder=None, + offload_index=None, + state_dict_folder=None, + state_dict_index=None, + dtype=None, + hf_quantizer=None, + is_safetensors=False, + keep_in_fp32_modules=None, + unexpected_keys=None, # passing `unexpected` for cleanup from quantization items + pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys +): + """ + This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its + params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the + params back to the normal device, but only for `loaded_state_dict_keys`. + + `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in + `bert.pooler.dense.weight` + + """ + + # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model + # - deepspeed zero 3 support + # - need to copy metadata if any - see _load_state_dict_into_model + # - handling error_msgs - mimicking the error handling in module._load_from_state_dict() + + error_msgs = [] + + old_keys = [] + new_keys = [] + renamed_gamma = {} + renamed_beta = {} + is_quantized = hf_quantizer is not None + warning_msg = f"This model {type(model)}" + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + # We add only the first key as an example + new_key = key.replace("gamma", "weight") + renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma + if "beta" in key: + # We add only the first key as an example + new_key = key.replace("beta", "bias") + renamed_beta[key] = new_key if not renamed_beta else renamed_beta + + # To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weigth norm, if necessary. + if hasattr(nn.utils.parametrizations, "weight_norm"): + if "weight_g" in key: + new_key = key.replace("weight_g", "parametrizations.weight.original0") + if "weight_v" in key: + new_key = key.replace("weight_v", "parametrizations.weight.original1") + else: + if "parametrizations.weight.original0" in key: + new_key = key.replace("parametrizations.weight.original0", "weight_g") + if "parametrizations.weight.original1" in key: + new_key = key.replace("parametrizations.weight.original1", "weight_v") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + renamed_keys = {**renamed_gamma, **renamed_beta} + if renamed_keys: + warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n" + for old_key, new_key in renamed_keys.items(): + warning_msg += f"* `{old_key}` -> `{new_key}`\n" + warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." + logger.info_once(warning_msg) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + + for param_name, param in state_dict.items(): + if param_name not in expected_keys: + continue + + if param_name.startswith(start_prefix): + param_name = param_name[len(start_prefix) :] + + module_name = param_name + set_module_kwargs = {} + + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn + if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn: + if ( + keep_in_fp32_modules is not None + and any( + module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + and dtype == torch.float16 + ): + param = param.to(torch.float32) + + # For backward compatibility with older versions of `accelerate` + # TODO: @sgugger replace this check with version check at the next `accelerate` release + if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters): + set_module_kwargs["dtype"] = torch.float32 + else: + param = param.to(dtype) + + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which + # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. + # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 + + old_param = model + splits = param_name.split(".") + for split in splits: + # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys. + old_param = getattr(old_param, split, None) + if old_param is None: + break + + if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): + old_param = None + + if old_param is not None: + if dtype is None: + param = param.to(old_param.dtype) + + if old_param.is_contiguous(): + param = param.contiguous() + + set_module_kwargs["value"] = param + + if device_map is None: + param_device = "cpu" + else: + # find next higher level module that is defined in device_map: + # bert.lm_head.weight -> bert.lm_head -> bert -> '' + while len(module_name) > 0 and module_name not in device_map: + module_name = ".".join(module_name.split(".")[:-1]) + if module_name == "" and "" not in device_map: + # TODO: group all errors and raise at the end. + raise ValueError(f"{param_name} doesn't have any device set.") + param_device = device_map[module_name] + + if param_device == "disk": + if not is_safetensors: + offload_index = offload_weight(param, param_name, offload_folder, offload_index) + elif param_device == "cpu" and state_dict_index is not None: + state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) + elif ( + not is_quantized + or (not hf_quantizer.requires_parameters_quantization) + or ( + not hf_quantizer.check_quantized_param( + model, param, param_name, state_dict, param_device=param_device, device_map=device_map + ) + ) + ): + if is_fsdp_enabled(): + param_device = "cpu" if is_local_dist_rank_0() else "meta" + + # For backward compatibility with older versions of `accelerate` and for non-quantized params + set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) + else: + hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys) + # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU + # and then cast it to CPU to avoid excessive memory usage on each GPU + # in comparison to the sharded model across GPUs. + if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): + module, tensor_name = get_module_from_name(model, param_name) + value = getattr(module, tensor_name) + param_to = "cpu" + if is_fsdp_enabled() and not is_local_dist_rank_0(): + param_to = "meta" + value = type(value)(value.data.to(param_to), **value.__dict__) + setattr(module, tensor_name, value) + # TODO: consider removing used param_parts from state_dict before return + + return error_msgs, offload_index, state_dict_index + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +class ModuleUtilsMixin: + """ + A few utilities for `torch.nn.Modules`, to be used as a mixin. + """ + + @staticmethod + def _hook_rss_memory_pre_forward(module, *args, **kwargs): + try: + import psutil + except ImportError: + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_pre_forward = mem.rss + return None + + @staticmethod + def _hook_rss_memory_post_forward(module, *args, **kwargs): + try: + import psutil + except ImportError: + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_post_forward = mem.rss + mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward + module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0) + return None + + def add_memory_hooks(self): + """ + Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. + + Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero + with `model.reset_memory_hooks_state()`. + """ + for module in self.modules(): + module.register_forward_pre_hook(self._hook_rss_memory_pre_forward) + module.register_forward_hook(self._hook_rss_memory_post_forward) + self.reset_memory_hooks_state() + + def reset_memory_hooks_state(self): + """ + Reset the `mem_rss_diff` attribute of each module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]). + """ + for module in self.modules(): + module.mem_rss_diff = 0 + module.mem_rss_post_forward = 0 + module.mem_rss_pre_forward = 0 + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min + + return encoder_extended_attention_mask + + @staticmethod + def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None): + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + else: + device = attention_mask.device + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + return extended_attention_mask + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = self.dtype + + if not (attention_mask.dim() == 2 and self.config.is_decoder): + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min + return extended_attention_mask + + def get_head_mask( + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False + ) -> Tensor: + """ + Prepare the head mask if needed. + + Args: + head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (`int`): + The number of hidden layers in the model. + is_attention_chunked (`bool`, *optional*, defaults to `False`): + Whether or not the attentions scores are computed by chunks or not. + + Returns: + `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with + `[None]` for each layer. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.unsqueeze(-1) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" + head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility + return head_mask + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding) + ] + total_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + else: + total_parameters = list(self.parameters()) + + total_numel = [] + is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) + + if is_loaded_in_4bit: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + else: + raise ValueError( + "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" + " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. " + ) + + for param in total_parameters: + if param.requires_grad or not only_trainable: + # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are + # used for the 4bit quantization (uint8 tensors are stored) + if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit): + if hasattr(param, "element_size"): + num_bytes = param.element_size() + elif hasattr(param, "quant_storage"): + num_bytes = param.quant_storage.itemsize + else: + num_bytes = 1 + total_numel.append(param.numel() * 2 * num_bytes) + else: + total_numel.append(param.numel()) + + return sum(total_numel) + + def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int: + """ + Helper function to estimate the total number of tokens from the model inputs. + + Args: + inputs (`dict`): The model inputs. + + Returns: + `int`: The total number of tokens. + """ + if not hasattr(self, "warnings_issued"): + self.warnings_issued = {} + if self.main_input_name in input_dict: + return input_dict[self.main_input_name].numel() + elif "estimate_tokens" not in self.warnings_issued: + logger.warning( + "Could not estimate the number of tokens of the input, floating-point operations will not be computed" + ) + self.warnings_issued["estimate_tokens"] = True + return 0 + + def floating_point_ops( + self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True + ) -> int: + """ + Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a + batch with this transformer model. Default approximation neglects the quadratic dependency on the number of + tokens (valid if `12 * d_model << sequence_length`) as laid out in [this + paper](https://arxiv.org/pdf/2001.08361.pdf) section 2.1. Should be overridden for transformers with parameter + re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths. + + Args: + batch_size (`int`): + The batch size for the forward pass. + + sequence_length (`int`): + The number of tokens in each line of the batch. + + exclude_embeddings (`bool`, *optional*, defaults to `True`): + Whether or not to count embedding and softmax operations. + + Returns: + `int`: The number of floating-point operations. + """ + + return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) + + +# TODO (joao): remove `GenerationMixin` inheritance in v4.50 +class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): + r""" + Base class for all models. + + [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models as well as a few methods common to all models to: + + - resize the input embeddings, + - prune heads in the self-attention heads. + + Class attributes (overridden by derived classes): + + - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + for this model architecture. + - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model, + taking as arguments: + + - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint. + - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model. + - **path** (`str`) -- A path to the TensorFlow checkpoint. + + - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived + classes of the same architecture adding modules on top of the base model. + - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization. + - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP + models, `pixel_values` for vision models and `input_values` for speech models). + """ + + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + model_tags = None + + _auto_class = None + _no_split_modules = None + _skip_keys_device_placement = None + _keep_in_fp32_modules = None + + # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing + # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. + _keys_to_ignore_on_load_missing = None + # a list of `re` patterns of `state_dict` keys that should be removed from the list of + # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary + # warnings. + _keys_to_ignore_on_load_unexpected = None + # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't + # trained, but which are either deterministic or tied variables) + _keys_to_ignore_on_save = None + # a list of `state_dict` keys that are potentially tied to another key in the state_dict. + _tied_weights_keys = None + + is_parallelizable = False + supports_gradient_checkpointing = False + _is_stateful = False + + # Flash Attention 2 support + _supports_flash_attn_2 = False + + # SDPA support + _supports_sdpa = False + + # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`? + _supports_cache_class = False + _supports_static_cache = False + + # Has support for a `QuantoQuantizedCache` instance as `past_key_values` + _supports_quantized_cache = False + + @property + def dummy_inputs(self) -> Dict[str, torch.Tensor]: + """ + `Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network. + """ + return {"input_ids": torch.tensor(DUMMY_INPUTS)} + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a PyTorch model. + """ + return "pt" + + def __init__(self, config: PretrainedConfig, *inputs, **kwargs): + super().__init__() + if not isinstance(config, PretrainedConfig): + raise ValueError( + f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " + "`PretrainedConfig`. To create a model from a pretrained model use " + f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + # Save config and origin of the pretrained weights if given in model + if not getattr(config, "_attn_implementation_autoset", False): + config = self._autoset_attn_implementation( + config, torch_dtype=torch.get_default_dtype(), check_device_map=False + ) + self.config = config + + self.name_or_path = config.name_or_path + self.warnings_issued = {} + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + # Overwrite the class attribute to make it an instance attribute, so models like + # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute + # when a different component (e.g. language_model) is used. + self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) + + def post_init(self): + """ + A method executed at the end of each Transformer model initialization, to execute code that needs the model's + modules properly initialized (such as weight initialization). + """ + self.init_weights() + self._backward_compatibility_gradient_checkpointing() + + def dequantize(self): + """ + Potentially dequantize the model in case it has been quantized by a quantization method that support + dequantization. + """ + hf_quantizer = getattr(self, "hf_quantizer", None) + + if hf_quantizer is None: + raise ValueError("You need to first quantize your model in order to dequantize it") + + return hf_quantizer.dequantize(self) + + def _backward_compatibility_gradient_checkpointing(self): + if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() + # Remove the attribute now that is has been consumed, so it's no saved in the config. + delattr(self.config, "gradient_checkpointing") + + def add_model_tags(self, tags: Union[List[str], str]) -> None: + r""" + Add custom tags into the model that gets pushed to the Hugging Face Hub. Will + not overwrite existing tags in the model. + + Args: + tags (`Union[List[str], str]`): + The desired tags to inject in the model + + Examples: + + ```python + from transformers import AutoModel + + model = AutoModel.from_pretrained("google-bert/bert-base-cased") + + model.add_model_tags(["custom", "custom-bert"]) + + # Push the model to your namespace with the name "my-custom-bert". + model.push_to_hub("my-custom-bert") + ``` + """ + if isinstance(tags, str): + tags = [tags] + + if self.model_tags is None: + self.model_tags = [] + + for tag in tags: + if tag not in self.model_tags: + self.model_tags.append(tag) + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + + Args: + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. + """ + # when we init a model from within another model (e.g. VLMs) and dispatch on FA2 + # a warning is raised that dtype should be fp16. Since we never pass dtype from within + # modeling code, we can try to infer it here same way as done in `from_pretrained` + torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype()) + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + + # override default dtype if needed + dtype_orig = None + if torch_dtype is not None: + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. + + if config._attn_implementation_internal is not None: + # In this case, the config has been created with the attn_implementation set by the user, which we + # should respect. + attn_implementation = config._attn_implementation_internal + else: + attn_implementation = None + + config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation) + if not getattr(config, "_attn_implementation_autoset", False): + config = cls._autoset_attn_implementation( + config, + use_flash_attention_2=use_flash_attention_2, + check_device_map=False, + torch_dtype=torch_dtype, + ) + + if is_deepspeed_zero3_enabled(): + import deepspeed + + logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") + # this immediately partitions the model across all gpus, to avoid the overhead in time + # and memory copying it on CPU or each GPU first + with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()): + model = cls(config, **kwargs) + + else: + model = cls(config, **kwargs) + + # restore default dtype if it was modified + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + return model + + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + """ + Automatically checks and dispatches to a default attention implementation. In order of priority: + 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). + 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) + 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) + 4. The default model's implementation otherwise (`LlamaAttention` for example) . + """ + # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. + # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). + # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) + requested_attn_implementation = None + if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: + if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: + raise ValueError( + f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.' + ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' + ) + + if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [ + "eager", + "sdpa", + "flash_attention_2", + ]: + message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' + if cls._supports_flash_attn_2: + message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + raise ValueError(message + ".") + + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. + requested_attn_implementation = config._attn_implementation_internal + + # Composite models consisting of several PretrainedModels have to specify attention impl as a dict + # where keys are sub-config names. But most people will specify one `str` which means that should dispatch it + # for all sub-models. + # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict. + # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)` + # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238 + for key in config: + if isinstance(getattr(config, key), PretrainedConfig): + sub_config = getattr(config, key) + curr_attn_implementation = ( + requested_attn_implementation + if not isinstance(requested_attn_implementation, dict) + else requested_attn_implementation.get(key, None) + ) + sub_config._attn_implementation_internal = curr_attn_implementation + + if use_flash_attention_2: + logger.warning_once( + 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' + ) + config._attn_implementation = "flash_attention_2" + + if config._attn_implementation == "flash_attention_2": + cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): + # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. + config = cls._check_and_enable_sdpa( + config, + hard_check_only=False if requested_attn_implementation is None else True, + ) + + if ( + torch.version.hip is not None + and config._attn_implementation == "sdpa" + and torch.cuda.device_count() > 1 + ): + logger.warning_once( + "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." + ) + torch.backends.cuda.enable_flash_sdp(False) + elif isinstance(requested_attn_implementation, dict): + config._attn_implementation = None + else: + config._attn_implementation = "eager" + + config._attn_implementation_autoset = True + return config + + @classmethod + def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: + """ + Change the default dtype and return the previous one. This is needed when wanting to instantiate the model + under specific dtype. + + Args: + dtype (`torch.dtype`): + a floating dtype to set to. + + Returns: + `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was + modified. If it wasn't, returns `None`. + + Note `set_default_dtype` currently only works with floating-point types and asserts if for example, + `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception. + """ + if not dtype.is_floating_point: + raise ValueError( + f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype" + ) + + logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.") + dtype_orig = torch.get_default_dtype() + torch.set_default_dtype(dtype) + return dtype_orig + + @property + def base_model(self) -> nn.Module: + """ + `torch.nn.Module`: The main body of the model. + """ + return getattr(self, self.base_model_prefix, self) + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Directly inherits `GenerationMixin` -> can generate + if "GenerationMixin" in str(cls.__bases__): + return True + # Model class overwrites `generate` (e.g. time series models) -> can generate + if str(cls.__name__) in str(cls.generate): + return True + # The class inherits from a class that can generate (recursive check) -> can generate + for base in cls.__bases__: + if not hasattr(base, "can_generate"): + continue + if "PreTrainedModel" not in str(base) and base.can_generate(): + return True + # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this + # was how we detected whether a model could generate. + if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): + logger.warning_once( + f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " + "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, " + "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " + "to call `generate` and other related functions." + "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " + "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes" + "\n - If you are the owner of the model architecture code, please modify your model class such that " + "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)." + "\n - If you are not the owner of the model architecture class, please contact the model code owner " + "to update it." + ) + return True + # Otherwise, can't generate + return False + + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 2 and compatibility with the current model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_2: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_2_available(): + preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." + + if importlib.util.find_spec("flash_attn") is None: + raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") + + flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + if torch.version.cuda: + if flash_attention_version < version.parse("2.1.0"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" + ) + elif not torch.cuda.is_available(): + raise ValueError( + f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device." + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") + elif torch.version.hip: + if flash_attention_version < version.parse("2.0.4"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}" + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + + if _is_bettertransformer: + raise ValueError( + "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + + if torch_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + logger.warning_once( + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but" + f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`' + ) + + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type != "cuda": + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) + elif ( + check_device_map + and device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." + ) + if not hard_check_only: + config._attn_implementation = "flash_attention_2" + return config + + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + """ + Checks the availability of SDPA for a given model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + """ + if hard_check_only: + if not cls._supports_sdpa: + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" + ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_torch_sdpa_available(): + raise ImportError( + "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1." + ) + + if not is_torch_sdpa_available() or not cls._supports_sdpa: + return config + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + if _is_bettertransformer: + return config + + if not hard_check_only: + config._attn_implementation = "sdpa" + return config + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping + the model weights fixed. + """ + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + + def disable_input_require_grads(self): + """ + Removes the `_require_grads_hook`. + """ + self._require_grads_hook.remove() + + def get_input_embeddings(self) -> nn.Module: + """ + Returns the model's input embeddings. + + Returns: + `nn.Module`: A torch module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.get_input_embeddings() + else: + raise NotImplementedError + + def set_input_embeddings(self, value: nn.Module): + """ + Set model's input embeddings. + + Args: + value (`nn.Module`): A module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + base_model.set_input_embeddings(value) + else: + raise NotImplementedError + + def get_output_embeddings(self) -> nn.Module: + """ + Returns the model's output embeddings. + + Returns: + `nn.Module`: A torch module mapping hidden states to vocabulary. + """ + return None # Overwrite for models with output embeddings + + def _init_weights(self, module): + """ + Initialize the weights. This method should be overridden by derived class and is + the only initialization method that will be called when loading a checkpoint + using `from_pretrained`. Any attempt to initialize outside of this function + will be useless as the torch.nn.init function are all replaced with skip. + """ + pass + + def _initialize_weights(self, module): + """ + Initialize the weights if they are not already initialized. + """ + if getattr(module, "_is_hf_initialized", False): + return + self._init_weights(module) + module._is_hf_initialized = True + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the + weights instead. + """ + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): + if hasattr(self, self.base_model_prefix): + self = getattr(self, self.base_model_prefix) + tied_weights = self._tie_encoder_decoder_weights( + self.encoder, self.decoder, self.base_model_prefix, "encoder" + ) + # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class + # attributed not an instance member, therefore modifying it will modify the entire class + # Leading to issues on subsequent calls by different tests or subsequent calls. + self._dynamic_tied_weights_keys = tied_weights + + for module in self.modules(): + if hasattr(module, "_tie_weights"): + module._tie_weights() + + @staticmethod + def _tie_encoder_decoder_weights( + encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str + ): + uninitialized_encoder_weights: List[str] = [] + tied_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder" + " weights are correctly initialized." + ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + base_encoder_name: str, + uninitialized_encoder_weights: List[str], + depth=0, + total_decoder_name="", + total_encoder_name="", + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" + if hasattr(decoder_pointer, "weight"): + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight") + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias") + encoder_pointer.bias = decoder_pointer.bias + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()} + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( + encoder_modules + ) != len(decoder_modules): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is" + " a circular dependency between two or more `nn.Modules` of your model." + ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + base_encoder_name, + uninitialized_encoder_weights, + depth=depth + 1, + total_encoder_name=f"{total_encoder_name}.{encoder_name}", + total_decoder_name=f"{total_decoder_name}.{decoder_name}", + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively( + decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights + ) + + if len(uninitialized_encoder_weights) > 0: + logger.warning( + f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" + ) + return tied_weights + + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): + """Tie or clone module weights depending of whether we are using TorchScript or not""" + if self.config.torchscript: + output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) + else: + output_embeddings.weight = input_embeddings.weight + + if getattr(output_embeddings, "bias", None) is not None: + output_embeddings.bias.data = nn.functional.pad( + output_embeddings.bias.data, + ( + 0, + output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], + ), + "constant", + 0, + ) + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): + output_embeddings.out_features = input_embeddings.num_embeddings + + def _get_no_split_modules(self, device_map: str): + """ + Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + get the underlying `_no_split_modules`. + + Args: + device_map (`str`): + The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"] + + Returns: + `List[str]`: List of modules that should not be split + """ + _no_split_modules = set() + modules_to_check = [self] + while len(modules_to_check) > 0: + module = modules_to_check.pop(-1) + # if the module does not appear in _no_split_modules, we also check the children + if module.__class__.__name__ not in _no_split_modules: + if isinstance(module, PreTrainedModel): + if module._no_split_modules is None: + raise ValueError( + f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + else: + _no_split_modules = _no_split_modules | set(module._no_split_modules) + modules_to_check += list(module.children()) + return list(_no_split_modules) + + def resize_token_embeddings( + self, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + new_num_tokens (`int`, *optional*): + The new number of tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + mean_resizing (`bool`): + Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and + covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`. + + Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models, + where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the + old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings. + Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + + Return: + `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. + """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + + # Since we are basically resuing the same old embeddings with new weight values, gathering is required + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None): + vocab_size = model_embeds.weight.shape[0] + else: + vocab_size = model_embeds.weight.shape[0] + + # Update base model and current model config. + self.config.get_text_config().vocab_size = vocab_size + self.vocab_size = vocab_size + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings( + old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing + ) + if hasattr(old_embeddings, "_hf_hook"): + hook = old_embeddings._hf_hook + add_hook_to_module(new_embeddings, hook) + old_embeddings_requires_grad = old_embeddings.weight.requires_grad + new_embeddings.requires_grad_(old_embeddings_requires_grad) + self.set_input_embeddings(new_embeddings) + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + + # Update new_num_tokens with the actual size of new_embeddings + if pad_to_multiple_of is not None: + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None): + new_num_tokens = new_embeddings.weight.shape[0] + else: + new_num_tokens = new_embeddings.weight.shape[0] + + # if word embeddings are not tied, make sure that lm head is resized as well + if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: + old_lm_head = self.get_output_embeddings() + if isinstance(old_lm_head, torch.nn.Embedding): + new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing) + else: + new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing) + if hasattr(old_lm_head, "_hf_hook"): + hook = old_lm_head._hf_hook + add_hook_to_module(new_lm_head, hook) + old_lm_head_requires_grad = old_lm_head.weight.requires_grad + new_lm_head.requires_grad_(old_lm_head_requires_grad) + self.set_output_embeddings(new_lm_head) + + return self.get_input_embeddings() + + def _get_resized_embeddings( + self, + old_embeddings: nn.Embedding, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + mean_resizing: bool = True, + ) -> nn.Embedding: + """ + Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly + initialized vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_embeddings (`torch.nn.Embedding`): + Old embeddings to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the embedding matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + mean_resizing (`bool`): + Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and + covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`. + + Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models, + where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the + old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings. + Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + + + Return: + `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if + `new_num_tokens` is `None` + """ + + if pad_to_multiple_of is not None: + if not isinstance(pad_to_multiple_of, int): + raise ValueError( + f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer" + ) + if new_num_tokens is None: + new_num_tokens = old_embeddings.weight.shape[0] + new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + else: + logger.info( + "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding" + f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." + " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" + " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" + ) + + if new_num_tokens is None: + return old_embeddings + + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None): + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + else: + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + + if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled(): + return old_embeddings + + if not isinstance(old_embeddings, nn.Embedding): + raise TypeError( + f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You" + " should either use a different resize function or make sure that `old_embeddings` are an instance of" + f" {nn.Embedding}." + ) + + # Build new embeddings + + # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init + # because the shape of the new embedding layer is used across various modeling files + # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading + # to errors when training. + new_embeddings = nn.Embedding( + new_num_tokens, + old_embedding_dim, + device=old_embeddings.weight.device, + dtype=old_embeddings.weight.dtype, + ) + + if new_num_tokens > old_num_tokens and not mean_resizing: + # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`. + self._init_weights(new_embeddings) + + elif new_num_tokens > old_num_tokens and mean_resizing: + # initialize new embeddings (in particular added tokens). The new embeddings will be initialized + # from a multivariate normal distribution that has old embeddings' mean and covariance. + # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + logger.warning_once( + "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. " + "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. " + "To disable this, use `mean_resizing=False`" + ) + + added_num_tokens = new_num_tokens - old_num_tokens + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None): + self._init_added_embeddings_weights_with_mean( + old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens + ) + else: + self._init_added_embeddings_weights_with_mean( + old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens + ) + + # Copy token embeddings from the previous weights + + # numbers of tokens to copy + n = min(old_num_tokens, new_num_tokens) + + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + params = [old_embeddings.weight, new_embeddings.weight] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] + else: + new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] + + # Replace weights in old_embeddings and return to maintain the same embedding type. + # This ensures correct functionality when a Custom Embedding class is passed as input. + # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979) + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + params = [old_embeddings.weight, new_embeddings.weight] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + old_embeddings.weight = new_embeddings.weight + old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0] + + # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx` + # will be set to `None` in the resized embeddings. + if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx: + old_embeddings.padding_idx = None + else: + old_embeddings.weight.data = new_embeddings.weight.data + old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0] + if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx: + old_embeddings.padding_idx = None + + return old_embeddings + + def _get_resized_lm_head( + self, + old_lm_head: nn.Linear, + new_num_tokens: Optional[int] = None, + transposed: Optional[bool] = False, + mean_resizing: bool = True, + ) -> nn.Linear: + """ + Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_lm_head (`torch.nn.Linear`): + Old lm head liner layer to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults + to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, + vocab_size` else `vocab_size, lm_head_dim`. + mean_resizing (`bool`): + Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and + covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`. + + Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models, + where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the + old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings. + Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + + Return: + `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is + `None` + """ + if new_num_tokens is None: + return old_lm_head + + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None): + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() + ) + else: + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() + ) + + if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled(): + return old_lm_head + + if not isinstance(old_lm_head, nn.Linear): + raise TypeError( + f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You" + " should either use a different resize function or make sure that `old_lm_head` are an instance of" + f" {nn.Linear}." + ) + + # Build new lm head + new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) + has_new_lm_head_bias = old_lm_head.bias is not None + + # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init + # because the shape of the new embedding layer is used across various modeling files + # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading + # to errors when training. + new_lm_head = nn.Linear( + *new_lm_head_shape, + bias=has_new_lm_head_bias, + device=old_lm_head.weight.device, + dtype=old_lm_head.weight.dtype, + ) + + if new_num_tokens > old_num_tokens and not mean_resizing: + # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`. + self._init_weights(new_lm_head) + + elif new_num_tokens > old_num_tokens and mean_resizing: + # initialize new lm_head weights (in particular added tokens). The new lm_head weights + # will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. + # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html + logger.warning_once( + "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. " + "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. " + "To disable this, use `mean_resizing=False`" + ) + + added_num_tokens = new_num_tokens - old_num_tokens + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + params = [old_lm_head.weight] + if has_new_lm_head_bias: + params += [old_lm_head.bias] + with deepspeed.zero.GatheredParameters(params, modifier_rank=None): + self._init_added_lm_head_weights_with_mean( + old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed + ) + if has_new_lm_head_bias: + self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens) + + else: + self._init_added_lm_head_weights_with_mean( + old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed + ) + if has_new_lm_head_bias: + self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens) + + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + else: + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + + return new_lm_head + + def _init_added_embeddings_weights_with_mean( + self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens + ): + old_embeddings_weight = old_embeddings.weight.data.to(torch.float32) + mean_embeddings = torch.mean(old_embeddings_weight, axis=0) + old_centered_embeddings = old_embeddings_weight - mean_embeddings + covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens + + # Check if the covariance is positive definite. + eigenvalues = torch.linalg.eigvals(covariance) + is_covariance_psd = bool( + (covariance == covariance.T).all() and not torch.is_complex(eigenvalues) and (eigenvalues > 0).all() + ) + if is_covariance_psd: + # If covariances is positive definite, a distribution can be created. and we can sample new weights from it. + distribution = torch.distributions.multivariate_normal.MultivariateNormal( + mean_embeddings, covariance_matrix=1e-9 * covariance + ) + new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample( + sample_shape=(added_num_tokens,) + ).to(old_embeddings.weight.dtype) + else: + # Otherwise, just initialize with the mean. because distribtion will not be created. + new_embeddings.weight.data[-1 * added_num_tokens :, :] = ( + mean_embeddings[None, :].repeat(added_num_tokens, 1).to(old_embeddings.weight.dtype) + ) + + def _init_added_lm_head_weights_with_mean( + self, + old_lm_head, + new_lm_head, + old_lm_head_dim, + old_num_tokens, + added_num_tokens, + transposed=False, + ): + if transposed: + # Transpose to the desired shape for the function. + new_lm_head.weight.data = new_lm_head.weight.data.T + old_lm_head.weight.data = old_lm_head.weight.data.T + + # The same initilization logic as Embeddings. + self._init_added_embeddings_weights_with_mean( + old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens + ) + + if transposed: + # Transpose again to the correct shape. + new_lm_head.weight.data = new_lm_head.weight.data.T + old_lm_head.weight.data = old_lm_head.weight.data.T + + def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens): + bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32) + bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32) + new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=1e-9 * bias_std) + + def _copy_lm_head_original_to_resized( + self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ): + # Copy old lm head weights to new lm head + if not transposed: + new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] + else: + new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] + + # Copy bias weights to new lm head + if has_new_lm_head_bias: + new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] + + def resize_position_embeddings(self, new_num_position_embeddings: int): + raise NotImplementedError( + f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" + ) + + def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]: + raise NotImplementedError( + f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" + ) + + def init_weights(self): + """ + If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any + initialization logic in `_init_weights`. + """ + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + + if _init_weights: + # Initialize weights + self.apply(self._initialize_weights) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyways + self.tie_weights() + + def prune_heads(self, heads_to_prune: Dict[int, List[int]]): + """ + Prunes heads of the base model. + + Arguments: + heads_to_prune (`Dict[int, List[int]]`): + Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads + to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on + layer 1 and heads 2 and 3 on layer 2. + """ + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads + for layer, heads in heads_to_prune.items(): + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON + + self.base_model._prune_heads(heads_to_prune) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): + is_gradient_checkpointing_set = False + + # Apply it on the top-level module in case the top-level modules supports it + # for example, LongT5Stack inherits from `PreTrainedModel`. + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + for module in self.modules(): + if hasattr(module, "gradient_checkpointing"): + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" + " `gradient_checkpointing` to modules of the model that uses checkpointing." + ) + + def gradient_checkpointing_disable(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self.supports_gradient_checkpointing: + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` methid + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=False) + else: + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + if getattr(self, "_hf_peft_config_loaded", False): + self.disable_input_require_grads() + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + push_to_hub: bool = False, + max_shard_size: Union[int, str] = "5GB", + safe_serialization: bool = True, + variant: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + save_peft_format: bool = True, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~PreTrainedModel.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + state_dict (nested dictionary of `torch.Tensor`): + The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only + save parts of the model or if special precautions need to be taken when recovering the state dictionary + of a model (like when using model parallelism). + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + We default it to 5GB in order for models to be able to run easily on free-tier google colab instances + without CPU OOM issues. + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + save_peft_format (`bool`, *optional*, defaults to `True`): + For backward compatibility with PEFT library, in case adapter weights are attached to the model, all + keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can + disable this behaviours by setting `save_peft_format` to `False`. + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) + + hf_quantizer = getattr(self, "hf_quantizer", None) + quantization_serializable = ( + hf_quantizer is not None + and isinstance(hf_quantizer, HfQuantizer) + and hf_quantizer.is_serializable(safe_serialization=safe_serialization) + ) + + if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable: + raise ValueError( + f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" + " the logger on the traceback to understand the reason why the quantized model is not serializable." + ) + + if "save_config" in kwargs: + warnings.warn( + "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." + ) + is_main_process = kwargs.pop("save_config") + if safe_serialization and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # Only save the model itself if we are using distributed training + model_to_save = unwrap_model(self) + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Unset attn implementation so it can be set to another one when loading back + model_to_save.config._attn_implementation_autoset = False + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self.config) + + # Save the config + if is_main_process: + if not _hf_peft_config_loaded: + # If the model config has set attributes that should be in the generation config, move them there. + misplaced_generation_parameters = model_to_save.config._get_non_default_generation_parameters() + if self.can_generate() and len(misplaced_generation_parameters) > 0: + warnings.warn( + "Moving the following attributes in the config to the generation config: " + f"{misplaced_generation_parameters}. You are seeing this warning because you've set " + "generation parameters in the model config, as opposed to in the generation config.", + UserWarning, + ) + for param_name, param_value in misplaced_generation_parameters.items(): + setattr(model_to_save.generation_config, param_name, param_value) + setattr(model_to_save.config, param_name, None) + + model_to_save.config.save_pretrained(save_directory) + if self.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) + + if _hf_peft_config_loaded: + logger.info( + "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." + ) + state_dict = model_to_save.get_adapter_state_dict() + + if save_peft_format: + logger.info( + "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`." + ) + peft_state_dict = {} + for key, value in state_dict.items(): + peft_state_dict[f"base_model.model.{key}"] = value + state_dict = peft_state_dict + + active_adapter = self.active_adapters() + + if len(active_adapter) > 1: + raise ValueError( + "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one " + "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`" + ) + active_adapter = active_adapter[0] + + current_peft_config = self.peft_config[active_adapter] + current_peft_config.save_pretrained(save_directory) + + # for offloaded modules + module_map = {} + + # Save the model + if state_dict is None: + # if any model parameters are offloaded, make module map + if ( + hasattr(self, "hf_device_map") + and len(set(self.hf_device_map.values())) > 1 + and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values()) + ): + warnings.warn( + "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)" + ) + for name, module in model_to_save.named_modules(): + if name == "": + continue + module_state_dict = module.state_dict() + + for key in module_state_dict: + module_map[name + f".{key}"] = module + state_dict = model_to_save.state_dict() + + # Translate state_dict from smp to hf if saving with smp >= 1.10 + if IS_SAGEMAKER_MP_POST_1_10: + for smp_to_hf, _ in smp.state.module_manager.translate_functions: + state_dict = smp_to_hf(state_dict) + + # Handle the case where some state_dict keys shouldn't be saved + if self._keys_to_ignore_on_save is not None: + for ignore_key in self._keys_to_ignore_on_save: + if ignore_key in state_dict.keys(): + del state_dict[ignore_key] + if safe_serialization: + # Safetensors does not allow tensor aliasing. + # We're going to remove aliases before saving + ptrs = collections.defaultdict(list) + for name, tensor in state_dict.items(): + # Sometimes in the state_dict we have non-tensor objects. + # e.g. in bitsandbytes we have some `str` objects in the state_dict + if isinstance(tensor, torch.Tensor): + ptrs[id_tensor_storage(tensor)].append(name) + else: + # In the non-tensor case, fall back to the pointer of the object itself + ptrs[id(tensor)].append(name) + + # These are all the pointers of shared tensors + if hasattr(self, "hf_device_map"): + # if the model has offloaded parameters, we must check using find_tied_parameters() + tied_params = find_tied_parameters(self) + if tied_params: + tied_names = tied_params[0] + shared_ptrs = { + ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names) + } + else: + shared_ptrs = {} + else: + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + + # Recursively descend to find tied weight keys + _tied_weights_keys = _get_tied_weight_keys(self) + error_names = [] + to_delete_names = set() + for names in shared_ptrs.values(): + # Removing the keys which are declared as known duplicates on + # load. This allows to make sure the name which is kept is consistent. + if _tied_weights_keys is not None: + found = 0 + for name in sorted(names): + matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys) + if matches_pattern and name in state_dict: + found += 1 + if found < len(names): + to_delete_names.add(name) + # We are entering a place where the weights and the transformers configuration do NOT match. + shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict) + # Those are actually tensor sharing but disjoint from each other, we can safely clone them + # Reloaded won't have the same property, but it shouldn't matter in any meaningful way. + for name in disjoint_names: + state_dict[name] = state_dict[name].clone() + + # When not all duplicates have been cleaned, still remove those keys, but put a clear warning. + # If the link between tensors was done at runtime then `from_pretrained` will not get + # the key back leading to random tensor. A proper warning will be shown + # during reload (if applicable), but since the file is not necessarily compatible with + # the config, better show a proper warning. + shared_names, identical_names = _find_identical(shared_names, state_dict) + # delete tensors that have identical storage + for inames in identical_names: + known = inames.intersection(to_delete_names) + for name in known: + del state_dict[name] + unknown = inames.difference(to_delete_names) + if len(unknown) > 1: + error_names.append(unknown) + + if shared_names: + error_names.append(set(shared_names)) + + if len(error_names) > 0: + raise RuntimeError( + f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.", + ) + + # Shard the model if it is too big. + if not _hf_peft_config_loaded: + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + else: + weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME + + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + # Save index if sharded + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in state_dict_split.filename_to_tensors.keys() + and is_main_process + and reg.fullmatch(filename_no_suffix) is not None + ): + os.remove(full_filename) + # Save the model + filename_to_tensors = state_dict_split.filename_to_tensors.items() + if module_map: + filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards") + for shard_file, tensors in filename_to_tensors: + shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} + # remake shard with onloaded parameters if necessary + if module_map: + if accelerate_version < version.parse("0.31"): + raise ImportError( + f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. " + f"Please upgrade accelerate with `pip install -U accelerate`" + ) + # init state_dict for this shard + shard_state_dict = {name: "" for name in shard} + for module_name in shard: + module = module_map[module_name] + # update state dict with onloaded parameters + shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict) + + # assign shard to be the completed state dict + shard = shard_state_dict + del shard_state_dict + gc.collect() + + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}) + else: + save_function(shard, os.path.join(save_directory, shard_file)) + + if index is None: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + if push_to_hub: + # Eventually create an empty model card + model_card = create_and_tag_model_card( + repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors + ) + + # Update model card if needed: + model_card.save(os.path.join(save_directory, "README.md")) + + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + ) + + @wraps(PushToHubMixin.push_to_hub) + def push_to_hub(self, *args, **kwargs): + tags = self.model_tags if self.model_tags is not None else [] + + tags_kwargs = kwargs.get("tags", []) + if isinstance(tags_kwargs, str): + tags_kwargs = [tags_kwargs] + + for tag in tags_kwargs: + if tag not in tags: + tags.append(tag) + + if tags: + kwargs["tags"] = tags + return super().push_to_hub(*args, **kwargs) + + def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the + PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers + are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch + norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem + + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: + raise ValueError("`.cuda` is not supported for HQQ-quantized models.") + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. " + " Please use the model as it is, since the model has already been set to the correct devices." + ) + elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + else: + return super().cuda(*args, **kwargs) + + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours. + # the correct API should be to load the model with the desired dtype directly through `from_pretrained`. + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ: + raise ValueError("`.to` is not supported for HQQ-quantized models.") + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if dtype_present_in_args: + raise ValueError( + "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" + " desired `dtype` by passing the correct `torch_dtype` argument." + ) + + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ: + if dtype_present_in_args: + raise ValueError( + "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired" + " `dtype` by passing the correct `torch_dtype` argument." + ) + return super().to(*args, **kwargs) + + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().half(*args) + + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().float(*args) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + weights_only: bool = True, + **kwargs, + ) -> "PreTrainedModel": + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded + in using the `meta` device and brought into memory once an input is passed through that layer regardless of + `low_cpu_mem_usage`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. + - `None` if you are both providing the configuration and state dictionary (resp. with keyword + arguments `config` and `state_dict`). + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + state_dict (`Dict[str, torch.Tensor]`, *optional*): + A state dictionary to use instead of a state dictionary loaded from saved weights file. + + This option can be used if you want to create a model from a pretrained configuration but load your own + weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and + [`~PreTrainedModel.from_pretrained`] is not a simpler option. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_tf (`bool`, *optional*, defaults to `False`): + Load the model weights from a TensorFlow checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. + + + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + _fast_init(`bool`, *optional*, defaults to `True`): + Whether or not to disable fast initialization. + + + + One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ < + 4.6.0` for seeded model initialization. This argument will be removed at the next major version. See + [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information. + + + attn_implementation (`str`, *optional*): + The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. + + > Parameters for big model inference + + low_cpu_mem_usage(`bool`, *optional*): + Tries not to use more than 1x model size in CPU memory (including peak memory) while loading the model. + Generally should be combined with a `device_map` (such as `"auto"`) for best results. + This is an experimental feature and a subject to change at any moment. + + If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without + `device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However, + this should still be enabled if you are passing in a `device_map`. + + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under a specific `dtype`. The different options + are: + + 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified + `dtype`, ignoring the model's `config.torch_dtype` if one exists. If not specified + - the model will get loaded in `torch.float` (fp32). + + 2. `"auto"` - A `torch_dtype` entry in the `config.json` file of the model will be + attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in + the checkpoint that's of a floating point type and use that as `dtype`. This will load the model + using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how + the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32. + + 3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc. + + + + For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or + reach out to the authors and ask them to add this information to the model's card and to insert the + `torch_dtype` entry in `config.json` on the hub. + + + + device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank + like `1`) on which the model will be allocated, the device map will map the entire model to this + device. Passing `device_map = 0` means put the whole model on GPU 0. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*): + If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU + RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to + `True` when there is some disk offload. + offload_buffers (`bool`, *optional*): + Whether or not to offload the buffers with the model parameters. + quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*): + A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g + bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and + `load_in_8bit`, which are parsed by QuantizationConfigParser. Supported only for bitsandbytes + quantizations and not preferred. consider inserting all such arguments into quantization_config + instead. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_tf` or `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` + is not installed, it will be set to `False`. + + weights_only (`bool`, *optional*, defaults to `True`): + Indicates whether unpickler should be restricted to loading only tensors, primitive types, + dictionaries and any types added via torch.serialization.add_safe_globals(). + When set to False, we can load wrapper tensor subclass weights. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + Examples: + + ```python + >>> from transformers import BertConfig, BertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = BertModel.from_pretrained("./test/saved_model/") + >>> # Update configuration during loading. + >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True) + >>> assert model.config.output_attentions == True + >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json") + >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config) + >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower) + >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True) + ``` + + * `low_cpu_mem_usage` algorithm: + + This is an experimental function that loads the model using ~1x model size CPU memory + + Here is how it works: + + 1. save which state_dict keys we have + 2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory + 3. after the model has been instantiated switch to the meta device all params/buffers that + are going to be replaced from the loaded state_dict + 4. load state_dict 2nd time + 5. replace the params/buffers from the state_dict + + Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors + + """ + state_dict = kwargs.pop("state_dict", None) + from_tf = kwargs.pop("from_tf", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", None) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + use_auth_token = kwargs.pop("use_auth_token", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + _ = kwargs.pop("mirror", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + _fast_init = kwargs.pop("_fast_init", True) + torch_dtype = kwargs.pop("torch_dtype", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + offload_buffers = kwargs.pop("offload_buffers", False) + load_in_8bit = kwargs.pop("load_in_8bit", False) + load_in_4bit = kwargs.pop("load_in_4bit", False) + quantization_config = kwargs.pop("quantization_config", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + variant = kwargs.pop("variant", None) + adapter_kwargs = kwargs.pop("adapter_kwargs", {}) + adapter_name = kwargs.pop("adapter_name", "default") + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + generation_config = kwargs.pop("generation_config", None) + + gguf_file = kwargs.pop("gguf_file", None) + # Cache path to the GGUF file + gguf_path = None + + if is_fsdp_enabled(): + low_cpu_mem_usage = True + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs: + adapter_kwargs["token"] = token + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." + ) + + if gguf_file is not None and not is_accelerate_available(): + raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.") + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): + # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible + resolved_config_file = cached_file( + pretrained_model_name_or_path, + CONFIG_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + if is_peft_available(): + _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None) + + if _adapter_model_path is None: + _adapter_model_path = find_adapter_config_file( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + _commit_hash=commit_hash, + **adapter_kwargs, + ) + if _adapter_model_path is not None and os.path.isfile(_adapter_model_path): + with open(_adapter_model_path, "r", encoding="utf-8") as f: + _adapter_model_path = pretrained_model_name_or_path + pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"] + else: + _adapter_model_path = None + + # change device_map into a map if we passed an int, a str or a torch.device + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + if low_cpu_mem_usage: + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`." + ) + elif not is_accelerate_available(): + raise ImportError( + f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + ) + + # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation. + if load_in_4bit or load_in_8bit: + if quantization_config is not None: + raise ValueError( + "You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing " + "`quantization_config` argument at the same time." + ) + + # preparing BitsAndBytesConfig from kwargs + config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters} + config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit} + quantization_config, kwargs = BitsAndBytesConfig.from_dict( + config_dict=config_dict, return_unused_kwargs=True, **kwargs + ) + logger.warning( + "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. " + "Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead." + ) + + from_pt = not (from_tf | from_flax) + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + else: + # In case one passes a config to `from_pretrained` + "attn_implementation" + # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs + # Please see: https://github.com/huggingface/transformers/issues/28038 + + # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory + # we pop attn_implementation from the kwargs but this handles the case where users + # passes manually the config to `from_pretrained`. + config = copy.deepcopy(config) + + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + if kwarg_attn_imp is not None: + config._attn_implementation = kwarg_attn_imp + + model_kwargs = kwargs + + pre_quantized = getattr(config, "quantization_config", None) is not None + if pre_quantized or quantization_config is not None: + if pre_quantized: + config.quantization_config = AutoHfQuantizer.merge_quantization_configs( + config.quantization_config, quantization_config + ) + else: + config.quantization_config = quantization_config + hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized) + else: + hf_quantizer = None + + if hf_quantizer is not None: + hf_quantizer.validate_environment( + torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map + ) + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + device_map = hf_quantizer.update_device_map(device_map) + + # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.warning("`low_cpu_mem_usage` was None, now default to True since model is quantized.") + is_quantized = hf_quantizer is not None + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + is_sharded = False + sharded_metadata = None + # Load model + loading_info = None + + # Keep in fp32 modules + keep_in_fp32_modules = None + use_keep_in_fp32_modules = False + + if gguf_file is not None and hf_quantizer is not None: + raise ValueError( + "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub." + ) + + if pretrained_model_name_or_path is not None and gguf_file is None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + ): + # Load from a TF 1.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + elif from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + ): + # Load from a TF 2.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + elif from_flax and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + ): + # Load from a Flax checkpoint in priority if from_flax + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) + ) + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif not use_safetensors and ( + os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")) + or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)) + ): + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use" + " `from_tf=True` to load this model from those weights." + ) + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + ): + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`" + " to load this model from those weights." + ) + elif use_safetensors: + raise EnvironmentError( + f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." + ) + else: + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" + f" {pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")): + if not from_tf: + raise ValueError( + f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set " + "from_tf to True to load from this checkpoint." + ) + archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if from_tf: + filename = TF2_WEIGHTS_NAME + elif from_flax: + filename = FLAX_WEIGHTS_NAME + elif use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + if revision == "main": + resolved_archive_file, revision, is_sharded = auto_conversion( + pretrained_model_name_or_path, **cached_file_kwargs + ) + cached_file_kwargs["revision"] = revision + if resolved_archive_file is None: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + if not local_files_only and not is_offline_mode(): + if resolved_archive_file is not None: + if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]: + # If the PyTorch file was found, check if there is a safetensors file on the repository + # If there is no safetensors file on the repositories, start an auto conversion + safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "resume_download": resume_download, + "local_files_only": local_files_only, + "user_agent": user_agent, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + **has_file_kwargs, + } + if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): + Thread( + target=auto_conversion, + args=(pretrained_model_name_or_path,), + kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs}, + name="Thread-autoconversion", + ).start() + else: + # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. + # We try those to give a helpful error message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights." + " Use `from_tf=True` to load this model from those weights." + ) + elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use" + " `from_flax=True` to load this model from those weights." + ) + elif variant is not None and has_file( + pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + ): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) + + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + elif gguf_file: + from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint + + # Case 1: the GGUF file is present locally + if os.path.isfile(gguf_file): + gguf_path = gguf_file + # Case 2: The GGUF path is a location on the Hub + # Load from URL or cache if already cached + else: + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs) + + state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["tensors"] + + resolved_archive_file = None + is_sharded = False + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + if ( + is_safetensors_available() + and isinstance(resolved_archive_file, str) + and resolved_archive_file.endswith(".safetensors") + ): + with safe_open(resolved_archive_file, framework="pt") as f: + metadata = f.metadata() + + if metadata.get("format") == "pt": + pass + elif metadata.get("format") == "tf": + from_tf = True + logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.") + elif metadata.get("format") == "flax": + from_flax = True + logger.info("A Flax safetensors file is being loaded in a PyTorch model.") + elif metadata.get("format") == "mlx": + # This is a mlx file, we assume weights are compatible with pt + pass + else: + raise ValueError( + f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}" + ) + + from_pt = not (from_tf | from_flax) + + # load pt weights early so that we know which dtype to init the model under + + if from_pt: + if not is_sharded and state_dict is None: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only) + + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first + # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype + # we also may have config.torch_dtype available, but we won't rely on it till v5 + dtype_orig = None + + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + torch_dtype = config.torch_dtype + logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object") + else: + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + elif not is_sharded: + torch_dtype = get_state_dict_dtype(state_dict) + else: + one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only) + torch_dtype = get_state_dict_dtype(one_state_dict) + del one_state_dict # free CPU memory + logger.info( + "Since the `torch_dtype` attribute can't be found in model's config object, " + "will use torch_dtype={torch_dtype} as derived from model's weights" + ) + elif hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) + else: + raise ValueError( + f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}' + ) + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + loaded_state_dict_keys = list(state_dict.keys()) + + if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())): + # In case some weights need to be kept in float32 and accelerate is not installed, + # we later on want to take the path where state_dict is not None, that is the one + # that do not require accelerate. + state_dict = None + + config.name_or_path = pretrained_model_name_or_path + + # Instantiate model. + init_contexts = [no_init_weights(_enable=_fast_init)] + + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") + init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts + elif low_cpu_mem_usage: + if not is_accelerate_available(): + raise ImportError( + f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + ) + init_contexts.append(init_empty_weights()) + + config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. + if not getattr(config, "_attn_implementation_autoset", False): + config = cls._autoset_attn_implementation( + config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map + ) + + with ContextManagers(init_contexts): + # Let's make sure we don't run the init function of buffer modules + model = cls(config, *model_args, **model_kwargs) + + # make sure we use the model's config since the __init__ call might have copied it + config = model.config + + # Check first if we are `from_pt` + if use_keep_in_fp32_modules: + if is_accelerate_available() and not is_deepspeed_zero3_enabled(): + low_cpu_mem_usage = True + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + + # We store the original dtype for quantized models as we cannot easily retrieve it + # once the weights have been quantized + # Note that once you have loaded a quantized model, you can't change its dtype so this will + # remain a single source of truth + config._pre_quantization_dtype = torch_dtype + + if isinstance(device_map, str): + special_dtypes = {} + + if hf_quantizer is not None: + special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) + + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in keep_in_fp32_modules) + } + ) + + target_dtype = torch_dtype + + if hf_quantizer is not None: + target_dtype = hf_quantizer.adjust_target_dtype(target_dtype) + + no_split_modules = model._get_no_split_modules(device_map) + if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + raise ValueError( + "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " + "'sequential'." + ) + + device_map_kwargs = {"no_split_module_classes": no_split_modules} + if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: + device_map_kwargs["special_dtypes"] = special_dtypes + elif len(special_dtypes) > 0: + logger.warning( + "This model has some weights that should be kept in higher precision, you need to upgrade " + "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." + ) + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + dtype=target_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **device_map_kwargs, + ) + else: + max_memory = get_max_memory(max_memory) + if hf_quantizer is not None: + max_memory = hf_quantizer.adjust_max_memory(max_memory) + device_map_kwargs["max_memory"] = max_memory + + # Make sure tied weights are tied before creating the device map. + model.tie_weights() + device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) + + elif device_map is not None: + model.tie_weights() + tied_params = find_tied_parameters(model) + # check if we don't have tied param in different devices + check_tied_parameters_on_same_device(tied_params, device_map) + + if from_tf: + if resolved_archive_file.endswith(".index"): + # Load from a TensorFlow 1.X checkpoint - provided by original authors + model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' + else: + # Load from our TensorFlow 2.0 checkpoints + try: + from transformers.modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model + + model, loading_info = load_tf2_checkpoint_in_pytorch_model( + model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True + ) + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed." + " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation" + " instructions." + ) + raise + elif from_flax: + try: + from transformers.modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) + except ImportError: + logger.error( + "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for" + " installation instructions." + ) + raise + elif from_pt: + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + loaded_state_dict_keys, # XXX: rename? + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + gguf_path=gguf_path, + weights_only=weights_only, + ) + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate() and generation_config is not None: + logger.info("The user-defined `generation_config` will be used to override the default generation config.") + model.generation_config = model.generation_config.from_dict(generation_config.to_dict()) + elif model.can_generate() and pretrained_model_name_or_path is not None: + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." + ) + pass + + # Dispatch model with hooks on all devices if necessary + if device_map is not None: + device_map_kwargs = { + "device_map": device_map, + "offload_dir": offload_folder, + "offload_index": offload_index, + "offload_buffers": offload_buffers, + } + if "skip_keys" in inspect.signature(dispatch_model).parameters: + device_map_kwargs["skip_keys"] = model._skip_keys_device_placement + # For HQQ method we force-set the hooks for single GPU envs + if ( + "force_hooks" in inspect.signature(dispatch_model).parameters + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ + ): + device_map_kwargs["force_hooks"] = True + if ( + hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8 + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + device_map_kwargs["offload_buffers"] = True + + if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + dispatch_model(model, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if _adapter_model_path is not None: + model.load_adapter( + _adapter_model_path, + adapter_name=adapter_name, + token=token, + adapter_kwargs=adapter_kwargs, + ) + + if output_loading_info: + if loading_info is None: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + loaded_keys, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + sharded_metadata=None, + _fast_init=True, + low_cpu_mem_usage=False, + device_map=None, + offload_folder=None, + offload_state_dict=None, + dtype=None, + hf_quantizer=None, + keep_in_fp32_modules=None, + gguf_path=None, + weights_only=True, + ): + is_safetensors = False + is_quantized = hf_quantizer is not None + state_dict_folder = None + state_dict_index = None + + if device_map is not None and "disk" in device_map.values(): + archive_file = ( + resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file + ) + is_safetensors = archive_file.endswith(".safetensors") + if offload_folder is None and not is_safetensors: + raise ValueError( + "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" + " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" + " offers the weights in this format." + ) + if offload_folder is not None: + os.makedirs(offload_folder, exist_ok=True) + if offload_state_dict is None: + offload_state_dict = True + + is_sharded_safetensors = is_safetensors and sharded_metadata is not None + + # tie the model weights before retrieving the state_dict + model.tie_weights() + + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + expected_keys = list(model_state_dict.keys()) + prefix = model.base_model_prefix + + if hf_quantizer is not None: + expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) + + def _fix_key(key): + if "beta" in key: + return key.replace("beta", "bias") + if "gamma" in key: + return key.replace("gamma", "weight") + + # to avoid logging parametrized weight norm renaming + if hasattr(nn.utils.parametrizations, "weight_norm"): + if "weight_g" in key: + return key.replace("weight_g", "parametrizations.weight.original0") + if "weight_v" in key: + return key.replace("weight_v", "parametrizations.weight.original1") + else: + if "parametrizations.weight.original0" in key: + return key.replace("parametrizations.weight.original0", "weight_g") + if "parametrizations.weight.original1" in key: + return key.replace("parametrizations.weight.original1", "weight_v") + return key + + original_loaded_keys = loaded_keys + loaded_keys = [_fix_key(key) for key in loaded_keys] + + if len(prefix) > 0: + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + else: + has_prefix_module = False + expects_prefix_module = False + + # key re-naming operations are never done on the keys + # that are loaded, but always on the keys of the newly initialized model + remove_prefix_from_model = not has_prefix_module and expects_prefix_module + add_prefix_to_model = has_prefix_module and not expects_prefix_module + + if remove_prefix_from_model: + _prefix = f"{prefix}." + expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)] + expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys] + elif add_prefix_to_model: + expected_keys = [".".join([prefix, s]) for s in expected_keys] + + missing_keys = sorted(set(expected_keys) - set(loaded_keys)) + unexpected_keys = set(loaded_keys) - set(expected_keys) + + # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model + # buffers + model_buffers = {n for n, _ in model.named_buffers()} + if remove_prefix_from_model: + model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers} + elif add_prefix_to_model: + model_buffers = {".".join([prefix, key]) for key in model_buffers} + unexpected_keys = sorted(unexpected_keys - model_buffers) + + model.tie_weights() + if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + ptrs = collections.defaultdict(list) + for name, tensor in model.state_dict().items(): + id_tensor = id_tensor_storage(tensor) + ptrs[id_tensor].append(name) + + # These are all the pointers of shared tensors. + tied_params = [names for _, names in ptrs.items() if len(names) > 1] + else: + # id function doesn't work for meta tensor so we need this function + tied_params = find_tied_parameters(model) + + for group in tied_params: + if remove_prefix_from_model: + group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group] + elif add_prefix_to_model: + group = [".".join([prefix, key]) for key in group] + missing_in_group = [k for k in missing_keys if k in group] + if len(missing_in_group) > 0 and len(missing_in_group) < len(group): + missing_keys = [k for k in missing_keys if k not in missing_in_group] + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) + + # retrieve weights on meta device and put them back on CPU. + # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step + if low_cpu_mem_usage: + for key in missing_keys: + if key in list(model_state_dict.keys()): + key = key + elif f"{prefix}.{key}" in list(model_state_dict.keys()): + key = f"{prefix}.{key}" + elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()): + key = ".".join(key.split(".")[1:]) + param = model_state_dict[key] + + # upcast in fp32 if any + target_dtype = dtype + if ( + keep_in_fp32_modules is not None + and dtype == torch.float16 + and any( + module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + ): + target_dtype = torch.float32 + + if param.device == torch.device("meta"): + value = torch.empty(*param.size(), dtype=target_dtype) + if ( + not is_quantized + or (getattr(hf_quantizer, "requires_parameters_quantization", False)) + or not hf_quantizer.check_quantized_param( + model, param_value=value, param_name=key, state_dict={} + ) + ): + set_module_tensor_to_device(model, key, "cpu", value) + else: + hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys) + + # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights. + if _fast_init: + if not ignore_mismatched_sizes: + if remove_prefix_from_model: + _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys] + elif add_prefix_to_model: + _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] + else: + _loaded_keys = loaded_keys + not_initialized_submodules = set_initialized_submodules(model, _loaded_keys) + # If we're about to tie the output embeds to the input embeds we don't need to init them + if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings: + output_embeddings = model.get_output_embeddings() + if output_embeddings is not None: + # Still need to initialize if there is a bias term since biases are not tied. + if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None: + output_embeddings._is_hf_initialized = True + else: + not_initialized_submodules = dict(model.named_modules()) + # This will only initialize submodules that are not marked as initialized by the line above. + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + not_initialized_parameters = list( + set( + itertools.chain.from_iterable( + submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values() + ) + ) + ) + with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0): + model.apply(model._initialize_weights) + else: + model.apply(model._initialize_weights) + + # Set some modules to fp32 if any + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + # param = param.to(torch.float32) does not work here as only in the local scope. + param.data = param.data.to(torch.float32) + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." + if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: + model_to_load = getattr(model, cls.base_model_prefix) + base_model_expected_keys = list(model_to_load.state_dict().keys()) + if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys): + raise ValueError( + "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " + "properly saved?" + ) + if device_map is not None: + device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()} + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + # If the checkpoint is sharded, we may not have the key here. + if checkpoint_key not in state_dict: + continue + model_key = checkpoint_key + if remove_prefix_from_model: + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f"{prefix}.{checkpoint_key}" + elif add_prefix_to_model: + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. + model_key = ".".join(checkpoint_key.split(".")[1:]) + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + if ( + state_dict[checkpoint_key].shape[-1] == 1 + and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel() + ): + # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. + # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights. + pass + else: + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if resolved_archive_file is not None: + folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) + else: + folder = None + if device_map is not None and is_safetensors: + param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) + str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" + if sharded_metadata is None: + archive_file = ( + resolved_archive_file[0] + if isinstance(resolved_archive_file, (list, tuple)) + else resolved_archive_file + ) + weight_map = {p: archive_file for p in original_loaded_keys} + else: + weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} + offload_index = { + p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} + for p, f in weight_map.items() + if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" + } + else: + offload_index = None + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + + # For GGUF models `state_dict` is never set to None as the state dict is always small + if gguf_path: + error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + hf_quantizer=hf_quantizer, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + ) + else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) + error_msgs = _load_state_dict_into_model( + model_to_load, state_dict, start_prefix, assign_to_params_buffers + ) + + else: + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + error_msgs = [] + mismatched_keys = [] + if not is_safetensors: + offload_index = {} if device_map is not None and "disk" in device_map.values() else None + if offload_state_dict: + state_dict_folder = tempfile.mkdtemp() + state_dict_index = {} + else: + state_dict_folder = None + state_dict_index = None + + if is_sharded_safetensors: + disk_only_shard_files = get_disk_only_shard_files( + device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix + ) + disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] + else: + disk_only_shard_files = [] + + if len(resolved_archive_file) > 1: + resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + assign_to_params_buffers = None + for shard_file in resolved_archive_file: + # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. + if shard_file in disk_only_shard_files: + continue + map_location = None + if ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantization_config.quant_type == "int4_weight_only" + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + state_dict = load_state_dict( + shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only + ) + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + if low_cpu_mem_usage: + if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: + for key, param in model_to_load.state_dict().items(): + if param.device == torch.device("meta"): + set_module_tensor_to_device( + model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: + new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + hf_quantizer=hf_quantizer, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + ) + error_msgs += new_error_msgs + else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True + if assign_to_params_buffers is None: + assign_to_params_buffers = check_support_param_buffer_assignment( + model_to_load, state_dict, start_prefix + ) + error_msgs += _load_state_dict_into_model( + model_to_load, state_dict, start_prefix, assign_to_params_buffers + ) + + # force memory release + del state_dict + gc.collect() + + if offload_index is not None and len(offload_index) > 0: + if model != model_to_load: + # We need to add the prefix of the base model + prefix = cls.base_model_prefix + if not is_safetensors: + for weight_name in offload_index: + shutil.move( + os.path.join(offload_folder, f"{weight_name}.dat"), + os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"), + ) + offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()} + if not is_safetensors: + save_offload_index(offload_index, offload_folder) + offload_index = None + + if offload_state_dict: + # Load back temporarily offloaded state dict + load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder) + shutil.rmtree(state_dict_folder) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + archs = [] if model.config.architectures is None else model.config.architectures + warner = logger.warning if model.__class__.__name__ in archs else logger.info + warner( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs + + def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): + module_keys = {".".join(key.split(".")[:-1]) for key in names} + + # torch.nn.ParameterList is a special case where two parameter keywords + # are appended to the module name, *e.g.* bert.special_embeddings.0 + module_keys = module_keys.union( + {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()} + ) + + retrieved_modules = [] + # retrieve all modules that has at least one missing weight name + for name, module in self.named_modules(): + if remove_prefix: + _prefix = f"{self.base_model_prefix}." + name = name[len(_prefix) :] if name.startswith(_prefix) else name + elif add_prefix: + name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix + + if name in module_keys: + retrieved_modules.append(module) + + return retrieved_modules + + @staticmethod + def _load_pretrained_model_low_mem( + model, + loaded_state_dict_keys, + resolved_archive_file, + start_prefix="", + hf_quantizer=None, + pretrained_model_name_or_path=None, + weights_only=True, + ): + """ + This is an experimental function that loads the model using ~1.x model size CPU memory + + Before you call it do: + + 1. save which state_dict keys are available + 2. drop state_dict before model is created, since the latter takes 1x model size memory + + Here then we continue: + + 3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict + 4. load state_dict 2nd time + 5. replace the params/buffers from the state_dict + + Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To + handle bitsandbytes, needs non-empty hf_quantizer argument. + """ + + _move_model_to_meta(model, loaded_state_dict_keys, start_prefix) + state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only) + expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys + error_msgs = _load_state_dict_into_meta_model( + model, + state_dict, + start_prefix, + expected_keys=expected_keys, + hf_quantizer=hf_quantizer, + ) + return error_msgs + + @classmethod + def register_for_auto_class(cls, auto_class="AutoModel"): + """ + Register this class with a given auto class. This should only be used for custom models as the ones in the + library are already mapped with an auto class. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`): + The auto class to register this new model with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def to_bettertransformer(self) -> "PreTrainedModel": + """ + Converts the model to use [PyTorch's native attention + implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), integrated to + Transformers through [Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a + subset of all Transformers models are supported. + + PyTorch's attention fastpath allows to speed up inference through kernel fusions and the use of [nested + tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog + post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2). + + Returns: + [`PreTrainedModel`]: The model converted to BetterTransformer. + """ + if not is_optimum_available(): + raise ImportError("The package `optimum` is required to use Better Transformer.") + + from optimum.version import __version__ as optimum_version + + if version.parse(optimum_version) < version.parse("1.7.0"): + raise ImportError( + f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found." + ) + + from optimum.bettertransformer import BetterTransformer + + return BetterTransformer.transform(self) + + def reverse_bettertransformer(self): + """ + Reverts the transformation from [`~PreTrainedModel.to_bettertransformer`] so that the original modeling is + used, for example in order to save the model. + + Returns: + [`PreTrainedModel`]: The model converted back to the original modeling. + """ + if not is_optimum_available(): + raise ImportError("The package `optimum` is required to use Better Transformer.") + + from optimum.version import __version__ as optimum_version + + if version.parse(optimum_version) < version.parse("1.7.0"): + raise ImportError( + f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found." + ) + + from optimum.bettertransformer import BetterTransformer + + return BetterTransformer.reverse(self) + + def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): + """ + Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given. + """ + + # Skip the check during tracing. + if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling(): + return + + if (attention_mask is not None) or (self.config.pad_token_id is None): + return + + # Check only the first and last input IDs to reduce overhead. + if self.config.pad_token_id in input_ids[:, [-1, 0]]: + warn_string = ( + "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See " + "https://huggingface.co/docs/transformers/troubleshooting" + "#incorrect-output-when-padding-tokens-arent-masked." + ) + + # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an + # attention_mask or not. In this case, we should still show a warning because this is a rare case. + if ( + (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) + or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) + or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) + ): + warn_string += ( + f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical " + f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), " + f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded." + ) + + logger.warning_once(warn_string) + + @property + def _is_quantized_training_enabled(self): + warnings.warn( + "`_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead", + FutureWarning, + ) + + if not hasattr(self, "hf_quantizer"): + return False + + return self.hf_quantizer.is_trainable + + @property + def loss_function(self): + if getattr(self.config, "loss_type", None) is not None: + loss_type = self.config.loss_type + else: + loss_type = self.__class__.__name__ + if loss_type not in LOSS_MAPPING: + loss_groups = f"({'|'.join(LOSS_MAPPING)})" + loss_type = re.findall(loss_groups, self.__class__.__name__) + if len(loss_type) > 0: + loss_type = loss_type[0] + else: + loss_type = None + if loss_type is None or loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None: + logger.warning_once( + f"`loss_type={loss_type}` was set in the config but it is unrecognised." + f"Using the default loss: `ForCausalLMLoss`." + ) + loss_type = "ForCausalLM" + return LOSS_MAPPING[loss_type] + + +PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) +if PreTrainedModel.push_to_hub.__doc__ is not None: + PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format( + object="model", object_class="AutoModel", object_files="model file" + ) + + +class PoolerStartLogits(nn.Module): + """ + Compute SQuAD start logits from sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward( + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + `torch.FloatTensor`: The start logits for SQuAD. + """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if get_parameter_dtype(self) == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerEndLogits(nn.Module): + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The end logits for SQuAD. + """ + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if get_parameter_dtype(self) == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerAnswerClass(nn.Module): + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The SQuAD 2.0 answer class. + """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. + hsz = hidden_states.shape[-1] + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +@dataclass +class SquadHeadOutput(ModelOutput): + """ + Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): + Classification loss as the sum of start token, end token (and is_impossible if provided) classification + losses. + start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top config.start_n_top start token possibilities (beam-search). + end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities + (beam-search). + end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). + cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the `is_impossible` label of the answers. + + """ + + loss: Optional[torch.FloatTensor] = None + start_top_log_probs: Optional[torch.FloatTensor] = None + start_top_index: Optional[torch.LongTensor] = None + end_top_log_probs: Optional[torch.FloatTensor] = None + end_top_index: Optional[torch.LongTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + + +class SQuADHead(nn.Module): + r""" + A SQuAD head inspired by XLNet. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config): + super().__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = PoolerStartLogits(config) + self.end_logits = PoolerEndLogits(config) + self.answer_class = PoolerAnswerClass(config) + + @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) + def forward( + self, + hidden_states: torch.FloatTensor, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + is_impossible: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + return_dict: bool = False, + ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the first token for the labeled span. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the last token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Whether the question has a possible answer in the paragraph or not. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + """ + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) + end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + + if not return_dict: + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + else: + return SquadHeadOutput( + start_top_log_probs=start_top_log_probs, + start_top_index=start_top_index, + end_top_log_probs=end_top_log_probs, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + +class SequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else Identity() + + self.first_dropout = Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (`torch.nn.Module`): The model to unwrap. + recursive (`bool`, *optional*, defaults to `False`): + Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers + recursively, not just the top-level distributed containers. + """ + # Use accelerate implementation if available (should always be the case when using torch) + # This is for pytorch, as we also have to handle things like dynamo + if is_accelerate_available(): + kwargs = {} + if recursive: + if not is_accelerate_available("0.29.0"): + raise RuntimeError( + "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate" + ) + else: + kwargs["recursive"] = recursive + return extract_model_from_parallel(model, **kwargs) + else: + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model + + +def expand_device_map(device_map, param_names, start_prefix): + """ + Expand a device map to return the correspondance parameter name to device. + """ + new_device_map = {} + param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)] + for module, device in device_map.items(): + new_device_map.update( + {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""} + ) + return new_device_map + + +def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): + """ + Returns the list of shard files containing only weights offloaded to disk. + """ + + weight_map = { + p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix) + } + files_content = collections.defaultdict(list) + for weight_name, filename in weight_map.items(): + while len(weight_name) > 0 and weight_name not in device_map: + weight_name = ".".join(weight_name.split(".")[:-1]) + files_content[filename].append(device_map[weight_name]) + + return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] diff --git a/indextts/infer.py b/indextts/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef3b7cf00873ce79dff8efbf6cc5827641fdaba --- /dev/null +++ b/indextts/infer.py @@ -0,0 +1,677 @@ +import os + +os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache' +import time +from subprocess import CalledProcessError +from typing import Dict, List + +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +from omegaconf import OmegaConf +from tqdm import tqdm + +import warnings + +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +from indextts.BigVGAN.models import BigVGAN as Generator +from indextts.gpt.model import UnifiedVoice +from indextts.utils.checkpoint import load_checkpoint +from indextts.utils.feature_extractors import MelSpectrogramFeatures + +from indextts.utils.front import TextNormalizer, TextTokenizer + + +class IndexTTS: + def __init__( + self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, device=None, + use_cuda_kernel=None, + ): + """ + Args: + cfg_path (str): path to the config file. + model_dir (str): path to the model directory. + is_fp16 (bool): whether to use fp16. + device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS. + use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device. + """ + if device is not None: + self.device = device + self.is_fp16 = False if device == "cpu" else is_fp16 + self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda") + elif torch.cuda.is_available(): + self.device = "cuda:0" + self.is_fp16 = is_fp16 + self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel + elif hasattr(torch, "mps") and torch.backends.mps.is_available(): + self.device = "mps" + self.is_fp16 = False # Use float16 on MPS is overhead than float32 + self.use_cuda_kernel = False + else: + self.device = "cpu" + self.is_fp16 = False + self.use_cuda_kernel = False + print(">> Be patient, it may take a while to run in CPU mode.") + + self.cfg = OmegaConf.load(cfg_path) + self.model_dir = model_dir + self.dtype = torch.float16 if self.is_fp16 else None + self.stop_mel_token = self.cfg.gpt.stop_mel_token + + # Comment-off to load the VQ-VAE model for debugging tokenizer + # https://github.com/index-tts/index-tts/issues/34 + # + # from indextts.vqvae.xtts_dvae import DiscreteVAE + # self.dvae = DiscreteVAE(**self.cfg.vqvae) + # self.dvae_path = os.path.join(self.model_dir, self.cfg.dvae_checkpoint) + # load_checkpoint(self.dvae, self.dvae_path) + # self.dvae = self.dvae.to(self.device) + # if self.is_fp16: + # self.dvae.eval().half() + # else: + # self.dvae.eval() + # print(">> vqvae weights restored from:", self.dvae_path) + self.gpt = UnifiedVoice(**self.cfg.gpt) + self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint) + load_checkpoint(self.gpt, self.gpt_path) + self.gpt = self.gpt.to(self.device) + if self.is_fp16: + self.gpt.eval().half() + else: + self.gpt.eval() + print(">> GPT weights restored from:", self.gpt_path) + if self.is_fp16: + try: + import deepspeed + + use_deepspeed = True + except (ImportError, OSError, CalledProcessError) as e: + use_deepspeed = False + print(f">> DeepSpeed加载失败,回退到标准推理: {e}") + + self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True) + else: + self.gpt.post_init_gpt2_config(use_deepspeed=False, kv_cache=False, half=False) + + if self.use_cuda_kernel: + # preload the CUDA kernel for BigVGAN + try: + from indextts.BigVGAN.alias_free_activation.cuda import load + + anti_alias_activation_cuda = load.load() + print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda) + except: + print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.") + self.use_cuda_kernel = False + self.bigvgan = Generator(self.cfg.bigvgan, use_cuda_kernel=self.use_cuda_kernel) + self.bigvgan_path = os.path.join(self.model_dir, self.cfg.bigvgan_checkpoint) + vocoder_dict = torch.load(self.bigvgan_path, map_location="cpu") + self.bigvgan.load_state_dict(vocoder_dict["generator"]) + self.bigvgan = self.bigvgan.to(self.device) + # remove weight norm on eval mode + self.bigvgan.remove_weight_norm() + self.bigvgan.eval() + print(">> bigvgan weights restored from:", self.bigvgan_path) + self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"]) + self.normalizer = TextNormalizer() + self.normalizer.load() + print(">> TextNormalizer loaded") + self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer) + print(">> bpe model loaded from:", self.bpe_path) + # 缓存参考音频mel: + self.cache_audio_prompt = None + self.cache_cond_mel = None + # 进度引用显示(可选) + self.gr_progress = None + self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None + + def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30): + """ + Shrink special tokens (silent_token and stop_mel_token) in codes + codes: [B, T] + """ + code_lens = [] + codes_list = [] + device = codes.device + dtype = codes.dtype + isfix = False + for i in range(0, codes.shape[0]): + code = codes[i] + if not torch.any(code == self.stop_mel_token).item(): + len_ = code.size(0) + else: + stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False) + len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0) + + count = torch.sum(code == silent_token).item() + if count > max_consecutive: + # code = code.cpu().tolist() + ncode_idx = [] + n = 0 + for k in range(len_): + assert code[ + k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should be shrinked here" + if code[k] != silent_token: + ncode_idx.append(k) + n = 0 + elif code[k] == silent_token and n < 10: + ncode_idx.append(k) + n += 1 + # if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52): + # n += 1 + # new code + len_ = len(ncode_idx) + codes_list.append(code[ncode_idx]) + isfix = True + else: + # shrink to len_ + codes_list.append(code[:len_]) + code_lens.append(len_) + if isfix: + if len(codes_list) > 1: + codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token) + else: + codes = codes_list[0].unsqueeze(0) + else: + # unchanged + pass + # clip codes to max length + max_len = max(code_lens) + if max_len < codes.shape[1]: + codes = codes[:, :max_len] + code_lens = torch.tensor(code_lens, dtype=torch.long, device=device) + return codes, code_lens + + def bucket_sentences(self, sentences, bucket_max_size=4) -> List[List[Dict]]: + """ + Sentence data bucketing. + if ``bucket_max_size=1``, return all sentences in one bucket. + """ + outputs: List[Dict] = [] + for idx, sent in enumerate(sentences): + outputs.append({"idx": idx, "sent": sent, "len": len(sent)}) + + if len(outputs) > bucket_max_size: + # split sentences into buckets by sentence length + buckets: List[List[Dict]] = [] + factor = 1.5 + last_bucket = None + last_bucket_sent_len_median = 0 + + for sent in sorted(outputs, key=lambda x: x["len"]): + current_sent_len = sent["len"] + if current_sent_len == 0: + print(">> skip empty sentence") + continue + if last_bucket is None \ + or current_sent_len >= int(last_bucket_sent_len_median * factor) \ + or len(last_bucket) >= bucket_max_size: + # new bucket + buckets.append([sent]) + last_bucket = buckets[-1] + last_bucket_sent_len_median = current_sent_len + else: + # current bucket can hold more sentences + last_bucket.append(sent) # sorted + mid = len(last_bucket) // 2 + last_bucket_sent_len_median = last_bucket[mid]["len"] + last_bucket = None + # merge all buckets with size 1 + out_buckets: List[List[Dict]] = [] + only_ones: List[Dict] = [] + for b in buckets: + if len(b) == 1: + only_ones.append(b[0]) + else: + out_buckets.append(b) + if len(only_ones) > 0: + # merge into previous buckets if possible + # print("only_ones:", [(o["idx"], o["len"]) for o in only_ones]) + for i in range(len(out_buckets)): + b = out_buckets[i] + if len(b) < bucket_max_size: + b.append(only_ones.pop(0)) + if len(only_ones) == 0: + break + # combined all remaining sized 1 buckets + if len(only_ones) > 0: + out_buckets.extend( + [only_ones[i:i + bucket_max_size] for i in range(0, len(only_ones), bucket_max_size)]) + return out_buckets + return [outputs] + + def pad_tokens_cat(self, tokens: List[torch.Tensor]) -> torch.Tensor: + if self.model_version and self.model_version >= 1.5: + # 1.5版本以上,直接使用stop_text_token 右侧填充,填充到最大长度 + # [1, N] -> [N,] + tokens = [t.squeeze(0) for t in tokens] + return pad_sequence(tokens, batch_first=True, padding_value=self.cfg.gpt.stop_text_token, + padding_side="right") + max_len = max(t.size(1) for t in tokens) + outputs = [] + for tensor in tokens: + pad_len = max_len - tensor.size(1) + if pad_len > 0: + n = min(8, pad_len) + tensor = torch.nn.functional.pad(tensor, (0, n), value=self.cfg.gpt.stop_text_token) + tensor = torch.nn.functional.pad(tensor, (0, pad_len - n), value=self.cfg.gpt.start_text_token) + tensor = tensor[:, :max_len] + outputs.append(tensor) + tokens = torch.cat(outputs, dim=0) + return tokens + + def torch_empty_cache(self): + try: + if "cuda" in str(self.device): + torch.cuda.empty_cache() + elif "mps" in str(self.device): + torch.mps.empty_cache() + except Exception as e: + pass + + def _set_gr_progress(self, value, desc): + if self.gr_progress is not None: + self.gr_progress(value, desc=desc) + + # 快速推理:对于“多句长文本”,可实现至少 2~10 倍以上的速度提升~ (First modified by sunnyboxs 2025-04-16) + def infer_fast(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=100, + sentences_bucket_max_size=4, **generation_kwargs): + """ + Args: + ``max_text_tokens_per_sentence``: 分句的最大token数,默认``100``,可以根据GPU硬件情况调整 + - 越小,batch 越多,推理速度越*快*,占用内存更多,可能影响质量 + - 越大,batch 越少,推理速度越*慢*,占用内存和质量更接近于非快速推理 + ``sentences_bucket_max_size``: 分句分桶的最大容量,默认``4``,可以根据GPU内存调整 + - 越大,bucket数量越少,batch越多,推理速度越*快*,占用内存更多,可能影响质量 + - 越小,bucket数量越多,batch越少,推理速度越*慢*,占用内存和质量更接近于非快速推理 + """ + print(">> start fast inference...") + + self._set_gr_progress(0, "start fast inference...") + if verbose: + print(f"origin text:{text}") + start_time = time.perf_counter() + + # 如果参考音频改变了,才需要重新生成 cond_mel, 提升速度 + if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt: + audio, sr = torchaudio.load(audio_prompt) + audio = torch.mean(audio, dim=0, keepdim=True) + if audio.shape[0] > 1: + audio = audio[0].unsqueeze(0) + audio = torchaudio.transforms.Resample(sr, 24000)(audio) + cond_mel = MelSpectrogramFeatures()(audio).to(self.device) + cond_mel_frame = cond_mel.shape[-1] + if verbose: + print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype) + + self.cache_audio_prompt = audio_prompt + self.cache_cond_mel = cond_mel + else: + cond_mel = self.cache_cond_mel + cond_mel_frame = cond_mel.shape[-1] + pass + + auto_conditioning = cond_mel + cond_mel_lengths = torch.tensor([cond_mel_frame], device=self.device) + + # text_tokens + text_tokens_list = self.tokenizer.tokenize(text) + + sentences = self.tokenizer.split_sentences(text_tokens_list, + max_tokens_per_sentence=max_text_tokens_per_sentence) + if verbose: + print(">> text token count:", len(text_tokens_list)) + print(" splited sentences count:", len(sentences)) + print(" max_text_tokens_per_sentence:", max_text_tokens_per_sentence) + print(*sentences, sep="\n") + do_sample = generation_kwargs.pop("do_sample", True) + top_p = generation_kwargs.pop("top_p", 0.8) + top_k = generation_kwargs.pop("top_k", 30) + temperature = generation_kwargs.pop("temperature", 1.0) + autoregressive_batch_size = 1 + length_penalty = generation_kwargs.pop("length_penalty", 0.0) + num_beams = generation_kwargs.pop("num_beams", 3) + repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0) + max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600) + sampling_rate = 24000 + # lang = "EN" + # lang = "ZH" + wavs = [] + gpt_gen_time = 0 + gpt_forward_time = 0 + bigvgan_time = 0 + + # text processing + all_text_tokens: List[List[torch.Tensor]] = [] + self._set_gr_progress(0.1, "text processing...") + bucket_max_size = sentences_bucket_max_size if self.device != "cpu" else 1 + all_sentences = self.bucket_sentences(sentences, bucket_max_size=bucket_max_size) + bucket_count = len(all_sentences) + if verbose: + print(">> sentences bucket_count:", bucket_count, + "bucket sizes:", [(len(s), [t["idx"] for t in s]) for s in all_sentences], + "bucket_max_size:", bucket_max_size) + for sentences in all_sentences: + temp_tokens: List[torch.Tensor] = [] + all_text_tokens.append(temp_tokens) + for item in sentences: + sent = item["sent"] + text_tokens = self.tokenizer.convert_tokens_to_ids(sent) + text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0) + if verbose: + print(text_tokens) + print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") + # debug tokenizer + text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist()) + print("text_token_syms is same as sentence tokens", text_token_syms == sent) + temp_tokens.append(text_tokens) + + # Sequential processing of bucketing data + all_batch_num = sum(len(s) for s in all_sentences) + all_batch_codes = [] + processed_num = 0 + for item_tokens in all_text_tokens: + batch_num = len(item_tokens) + if batch_num > 1: + batch_text_tokens = self.pad_tokens_cat(item_tokens) + else: + batch_text_tokens = item_tokens[0] + processed_num += batch_num + # gpt speech + self._set_gr_progress(0.2 + 0.3 * processed_num / all_batch_num, + f"gpt inference speech... {processed_num}/{all_batch_num}") + m_start_time = time.perf_counter() + with torch.no_grad(): + with torch.amp.autocast(batch_text_tokens.device.type, enabled=self.dtype is not None, + dtype=self.dtype): + temp_codes = self.gpt.inference_speech(auto_conditioning, batch_text_tokens, + cond_mel_lengths=cond_mel_lengths, + # text_lengths=text_len, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=autoregressive_batch_size, + length_penalty=length_penalty, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **generation_kwargs) + all_batch_codes.append(temp_codes) + gpt_gen_time += time.perf_counter() - m_start_time + + # gpt latent + self._set_gr_progress(0.5, "gpt inference latents...") + all_idxs = [] + all_latents = [] + has_warned = False + for batch_codes, batch_tokens, batch_sentences in zip(all_batch_codes, all_text_tokens, all_sentences): + for i in range(batch_codes.shape[0]): + codes = batch_codes[i] # [x] + if not has_warned and codes[-1] != self.stop_mel_token: + warnings.warn( + f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). " + f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.", + category=RuntimeWarning + ) + has_warned = True + codes = codes.unsqueeze(0) # [x] -> [1, x] + if verbose: + print("codes:", codes.shape) + print(codes) + codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30) + if verbose: + print("fix codes:", codes.shape) + print(codes) + print("code_lens:", code_lens) + text_tokens = batch_tokens[i] + all_idxs.append(batch_sentences[i]["idx"]) + m_start_time = time.perf_counter() + with torch.no_grad(): + with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): + latent = \ + self.gpt(auto_conditioning, text_tokens, + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + code_lens * self.gpt.mel_length_compression, + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], + device=text_tokens.device), + return_latent=True, clip_inputs=False) + gpt_forward_time += time.perf_counter() - m_start_time + all_latents.append(latent) + del all_batch_codes, all_text_tokens, all_sentences + # bigvgan chunk + chunk_size = 2 + all_latents = [all_latents[all_idxs.index(i)] for i in range(len(all_latents))] + if verbose: + print(">> all_latents:", len(all_latents)) + print(" latents length:", [l.shape[1] for l in all_latents]) + chunk_latents = [all_latents[i: i + chunk_size] for i in range(0, len(all_latents), chunk_size)] + chunk_length = len(chunk_latents) + latent_length = len(all_latents) + + # bigvgan chunk decode + self._set_gr_progress(0.7, "bigvgan decode...") + tqdm_progress = tqdm(total=latent_length, desc="bigvgan") + for items in chunk_latents: + tqdm_progress.update(len(items)) + latent = torch.cat(items, dim=1) + with torch.no_grad(): + with torch.amp.autocast(latent.device.type, enabled=self.dtype is not None, dtype=self.dtype): + m_start_time = time.perf_counter() + wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2)) + bigvgan_time += time.perf_counter() - m_start_time + wav = wav.squeeze(1) + pass + wav = torch.clamp(32767 * wav, -32767.0, 32767.0) + wavs.append(wav.cpu()) # to cpu before saving + + # clear cache + tqdm_progress.close() # 确保进度条被关闭 + del all_latents, chunk_latents + end_time = time.perf_counter() + self.torch_empty_cache() + + # wav audio output + self._set_gr_progress(0.9, "save audio...") + wav = torch.cat(wavs, dim=1) + wav_length = wav.shape[-1] / sampling_rate + print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds") + print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds") + print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds") + print(f">> bigvgan_time: {bigvgan_time:.2f} seconds") + print(f">> Total fast inference time: {end_time - start_time:.2f} seconds") + print(f">> Generated audio length: {wav_length:.2f} seconds") + print(f">> [fast] bigvgan chunk_length: {chunk_length}") + print(f">> [fast] batch_num: {all_batch_num} bucket_max_size: {bucket_max_size}", + f"bucket_count: {bucket_count}" if bucket_max_size > 1 else "") + print(f">> [fast] RTF: {(end_time - start_time) / wav_length:.4f}") + + # save audio + wav = wav.cpu() # to cpu + if output_path: + # 直接保存音频到指定路径中 + os.makedirs(os.path.dirname(output_path), exist_ok=True) + torchaudio.save(output_path, wav.type(torch.int16), sampling_rate) + print(">> wav file saved to:", output_path) + return output_path + else: + # 返回以符合Gradio的格式要求 + wav_data = wav.type(torch.int16) + wav_data = wav_data.numpy().T + return (sampling_rate, wav_data) + + # 原始推理模式 + def infer(self, audio_prompt, text, output_path, verbose=False, max_text_tokens_per_sentence=120, + **generation_kwargs): + print(">> start inference...") + self._set_gr_progress(0, "start inference...") + if verbose: + print(f"origin text:{text}") + start_time = time.perf_counter() + + # 如果参考音频改变了,才需要重新生成 cond_mel, 提升速度 + if self.cache_cond_mel is None or self.cache_audio_prompt != audio_prompt: + audio, sr = torchaudio.load(audio_prompt) + audio = torch.mean(audio, dim=0, keepdim=True) + if audio.shape[0] > 1: + audio = audio[0].unsqueeze(0) + audio = torchaudio.transforms.Resample(sr, 24000)(audio) + cond_mel = MelSpectrogramFeatures()(audio).to(self.device) + cond_mel_frame = cond_mel.shape[-1] + if verbose: + print(f"cond_mel shape: {cond_mel.shape}", "dtype:", cond_mel.dtype) + + self.cache_audio_prompt = audio_prompt + self.cache_cond_mel = cond_mel + else: + cond_mel = self.cache_cond_mel + cond_mel_frame = cond_mel.shape[-1] + pass + + self._set_gr_progress(0.1, "text processing...") + auto_conditioning = cond_mel + text_tokens_list = self.tokenizer.tokenize(text) + sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence) + if verbose: + print("text token count:", len(text_tokens_list)) + print("sentences count:", len(sentences)) + print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence) + print(*sentences, sep="\n") + do_sample = generation_kwargs.pop("do_sample", True) + top_p = generation_kwargs.pop("top_p", 0.8) + top_k = generation_kwargs.pop("top_k", 30) + temperature = generation_kwargs.pop("temperature", 1.0) + autoregressive_batch_size = 1 + length_penalty = generation_kwargs.pop("length_penalty", 0.0) + num_beams = generation_kwargs.pop("num_beams", 3) + repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0) + max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 600) + sampling_rate = 24000 + # lang = "EN" + # lang = "ZH" + wavs = [] + gpt_gen_time = 0 + gpt_forward_time = 0 + bigvgan_time = 0 + progress = 0 + has_warned = False + for sent in sentences: + text_tokens = self.tokenizer.convert_tokens_to_ids(sent) + text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0) + # text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + # text_tokens = F.pad(text_tokens, (1, 0), value=0) + # text_tokens = F.pad(text_tokens, (0, 1), value=1) + if verbose: + print(text_tokens) + print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") + # debug tokenizer + text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist()) + print("text_token_syms is same as sentence tokens", text_token_syms == sent) + + # text_len = torch.IntTensor([text_tokens.size(1)], device=text_tokens.device) + # print(text_len) + progress += 1 + self._set_gr_progress(0.2 + 0.4 * (progress - 1) / len(sentences), + f"gpt inference latent... {progress}/{len(sentences)}") + m_start_time = time.perf_counter() + with torch.no_grad(): + with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): + codes = self.gpt.inference_speech(auto_conditioning, text_tokens, + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], + device=text_tokens.device), + # text_lengths=text_len, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=autoregressive_batch_size, + length_penalty=length_penalty, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **generation_kwargs) + gpt_gen_time += time.perf_counter() - m_start_time + if not has_warned and (codes[:, -1] != self.stop_mel_token).any(): + warnings.warn( + f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). " + f"Input text tokens: {text_tokens.shape[1]}. " + f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.", + category=RuntimeWarning + ) + has_warned = True + + code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype) + if verbose: + print(codes, type(codes)) + print(f"codes shape: {codes.shape}, codes type: {codes.dtype}") + print(f"code len: {code_lens}") + + # remove ultra-long silence if exits + # temporarily fix the long silence bug. + codes, code_lens = self.remove_long_silence(codes, silent_token=52, max_consecutive=30) + if verbose: + print(codes, type(codes)) + print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}") + print(f"code len: {code_lens}") + self._set_gr_progress(0.2 + 0.4 * progress / len(sentences), + f"gpt inference speech... {progress}/{len(sentences)}") + m_start_time = time.perf_counter() + # latent, text_lens_out, code_lens_out = \ + with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): + latent = \ + self.gpt(auto_conditioning, text_tokens, + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), codes, + code_lens * self.gpt.mel_length_compression, + cond_mel_lengths=torch.tensor([auto_conditioning.shape[-1]], + device=text_tokens.device), + return_latent=True, clip_inputs=False) + gpt_forward_time += time.perf_counter() - m_start_time + + m_start_time = time.perf_counter() + wav, _ = self.bigvgan(latent, auto_conditioning.transpose(1, 2)) + bigvgan_time += time.perf_counter() - m_start_time + wav = wav.squeeze(1) + + wav = torch.clamp(32767 * wav, -32767.0, 32767.0) + if verbose: + print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max()) + # wavs.append(wav[:, :-512]) + wavs.append(wav.cpu()) # to cpu before saving + end_time = time.perf_counter() + self._set_gr_progress(0.9, "save audio...") + wav = torch.cat(wavs, dim=1) + wav_length = wav.shape[-1] / sampling_rate + print(f">> Reference audio length: {cond_mel_frame * 256 / sampling_rate:.2f} seconds") + print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds") + print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds") + print(f">> bigvgan_time: {bigvgan_time:.2f} seconds") + print(f">> Total inference time: {end_time - start_time:.2f} seconds") + print(f">> Generated audio length: {wav_length:.2f} seconds") + print(f">> RTF: {(end_time - start_time) / wav_length:.4f}") + + # save audio + wav = wav.cpu() # to cpu + if output_path: + # 直接保存音频到指定路径中 + if os.path.isfile(output_path): + os.remove(output_path) + print(">> remove old wav file:", output_path) + if os.path.dirname(output_path) != "": + os.makedirs(os.path.dirname(output_path), exist_ok=True) + torchaudio.save(output_path, wav.type(torch.int16), sampling_rate) + print(">> wav file saved to:", output_path) + return output_path + else: + # 返回以符合Gradio的格式要求 + wav_data = wav.type(torch.int16) + wav_data = wav_data.numpy().T + return (sampling_rate, wav_data) + +if __name__ == "__main__": + prompt_wav = "examples/voice_01.wav" + text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。' + + tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False) + tts.infer(audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True) diff --git a/indextts/infer_v2.py b/indextts/infer_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..3edcb853a902637662fc62fc8fa96cf11d4ed0ad --- /dev/null +++ b/indextts/infer_v2.py @@ -0,0 +1,694 @@ +import os +from subprocess import CalledProcessError + +import time +import librosa +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +import warnings + +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +from omegaconf import OmegaConf + +from indextts.gpt.model_v2 import UnifiedVoice +from indextts.utils.maskgct_utils import build_semantic_model, build_semantic_codec +from indextts.utils.checkpoint import load_checkpoint +from indextts.utils.front import TextNormalizer, TextTokenizer + +from indextts.s2mel.modules.commons import load_checkpoint2, MyModel +from indextts.s2mel.modules.bigvgan import bigvgan +from indextts.s2mel.modules.campplus.DTDNN import CAMPPlus +from indextts.s2mel.modules.audio import mel_spectrogram + +from transformers import AutoTokenizer +from modelscope import AutoModelForCausalLM +from huggingface_hub import hf_hub_download +import safetensors +from transformers import SeamlessM4TFeatureExtractor +import random +import torch.nn.functional as F + +class IndexTTS2: + def __init__( + self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, device=None, + use_cuda_kernel=None, + ): + """ + Args: + cfg_path (str): path to the config file. + model_dir (str): path to the model directory. + is_fp16 (bool): whether to use fp16. + device (str): device to use (e.g., 'cuda:0', 'cpu'). If None, it will be set automatically based on the availability of CUDA or MPS. + use_cuda_kernel (None | bool): whether to use BigVGan custom fused activation CUDA kernel, only for CUDA device. + """ + if device is not None: + self.device = device + self.is_fp16 = False if device == "cpu" else is_fp16 + self.use_cuda_kernel = use_cuda_kernel is not None and use_cuda_kernel and device.startswith("cuda") + elif torch.cuda.is_available(): + self.device = "cuda:0" + self.is_fp16 = is_fp16 + self.use_cuda_kernel = use_cuda_kernel is None or use_cuda_kernel + elif hasattr(torch, "mps") and torch.backends.mps.is_available(): + self.device = "mps" + self.is_fp16 = False # Use float16 on MPS is overhead than float32 + self.use_cuda_kernel = False + else: + self.device = "cpu" + self.is_fp16 = False + self.use_cuda_kernel = False + print(">> Be patient, it may take a while to run in CPU mode.") + + self.cfg = OmegaConf.load(cfg_path) + self.model_dir = model_dir + self.dtype = torch.float16 if self.is_fp16 else None + self.stop_mel_token = self.cfg.gpt.stop_mel_token + + self.qwen_emo = QwenEmotion(os.path.join(self.model_dir, self.cfg.qwen_emo_path)) + + self.gpt = UnifiedVoice(**self.cfg.gpt) + self.gpt_path = os.path.join(self.model_dir, self.cfg.gpt_checkpoint) + load_checkpoint(self.gpt, self.gpt_path) + self.gpt = self.gpt.to(self.device) + if self.is_fp16: + self.gpt.eval().half() + else: + self.gpt.eval() + print(">> GPT weights restored from:", self.gpt_path) + if self.is_fp16: + try: + import deepspeed + + use_deepspeed = True + except (ImportError, OSError, CalledProcessError) as e: + use_deepspeed = False + print(f">> DeepSpeed加载失败,回退到标准推理: {e}") + + self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True) + else: + self.gpt.post_init_gpt2_config(use_deepspeed=True, kv_cache=True, half=False) + + if self.use_cuda_kernel: + # preload the CUDA kernel for BigVGAN + try: + from indextts.BigVGAN.alias_free_activation.cuda import load + + anti_alias_activation_cuda = load.load() + print(">> Preload custom CUDA kernel for BigVGAN", anti_alias_activation_cuda) + except: + print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.") + self.use_cuda_kernel = False + + self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0") + self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model( + os.path.join(self.model_dir, self.cfg.w2v_stat)) + self.semantic_model = self.semantic_model.to(self.device) + self.semantic_model.eval() + self.semantic_mean = self.semantic_mean.to(self.device) + self.semantic_std = self.semantic_std.to(self.device) + + semantic_codec = build_semantic_codec(self.cfg.semantic_codec) + semantic_code_ckpt = hf_hub_download("amphion/MaskGCT", filename="semantic_codec/model.safetensors") + safetensors.torch.load_model(semantic_codec, semantic_code_ckpt) + self.semantic_codec = semantic_codec.to(self.device) + self.semantic_codec.eval() + print('>> semantic_codec weights restored from: {}'.format(semantic_code_ckpt)) + + s2mel_path = os.path.join(self.model_dir, self.cfg.s2mel_checkpoint) + s2mel = MyModel(self.cfg.s2mel, use_gpt_latent=True) + s2mel, _, _, _ = load_checkpoint2( + s2mel, + None, + s2mel_path, + load_only_params=True, + ignore_modules=[], + is_distributed=False, + ) + self.s2mel = s2mel.to(self.device) + self.s2mel.models['cfm'].estimator.setup_caches(max_batch_size=1, max_seq_length=8192) + self.s2mel.eval() + print(">> s2mel weights restored from:", s2mel_path) + + # load campplus_model + campplus_ckpt_path = hf_hub_download( + "funasr/campplus", filename="campplus_cn_common.bin" + ) + campplus_model = CAMPPlus(feat_dim=80, embedding_size=192) + campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu")) + self.campplus_model = campplus_model.to(self.device) + self.campplus_model.eval() + print(">> campplus_model weights restored from:", campplus_ckpt_path) + + bigvgan_name = self.cfg.vocoder.name + self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=False) + self.bigvgan = self.bigvgan.to(self.device) + self.bigvgan.remove_weight_norm() + self.bigvgan.eval() + print(">> bigvgan weights restored from:", bigvgan_name) + + self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"]) + self.normalizer = TextNormalizer() + self.normalizer.load() + print(">> TextNormalizer loaded") + self.tokenizer = TextTokenizer(self.bpe_path, self.normalizer) + print(">> bpe model loaded from:", self.bpe_path) + + emo_matrix = torch.load(os.path.join(self.model_dir, self.cfg.emo_matrix)) + self.emo_matrix = emo_matrix.to(self.device) + self.emo_num = list(self.cfg.emo_num) + + spk_matrix = torch.load(os.path.join(self.model_dir, self.cfg.spk_matrix)) + self.spk_matrix = spk_matrix.to(self.device) + + self.emo_matrix = torch.split(self.emo_matrix, self.emo_num) + self.spk_matrix = torch.split(self.spk_matrix, self.emo_num) + + mel_fn_args = { + "n_fft": self.cfg.s2mel['preprocess_params']['spect_params']['n_fft'], + "win_size": self.cfg.s2mel['preprocess_params']['spect_params']['win_length'], + "hop_size": self.cfg.s2mel['preprocess_params']['spect_params']['hop_length'], + "num_mels": self.cfg.s2mel['preprocess_params']['spect_params']['n_mels'], + "sampling_rate": self.cfg.s2mel["preprocess_params"]["sr"], + "fmin": self.cfg.s2mel['preprocess_params']['spect_params'].get('fmin', 0), + "fmax": None if self.cfg.s2mel['preprocess_params']['spect_params'].get('fmax', "None") == "None" else 8000, + "center": False + } + self.mel_fn = lambda x: mel_spectrogram(x, **mel_fn_args) + + # 缓存参考音频: + self.cache_spk_cond = None + self.cache_s2mel_style = None + self.cache_s2mel_prompt = None + self.cache_spk_audio_prompt = None + self.cache_emo_cond = None + self.cache_emo_audio_prompt = None + self.cache_mel = None + + # 进度引用显示(可选) + self.gr_progress = None + self.model_version = self.cfg.version if hasattr(self.cfg, "version") else None + + @torch.no_grad() + def get_emb(self, input_features, attention_mask): + vq_emb = self.semantic_model( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + ) + feat = vq_emb.hidden_states[17] # (B, T, C) + feat = (feat - self.semantic_mean) / self.semantic_std + return feat + + def remove_long_silence(self, codes: torch.Tensor, silent_token=52, max_consecutive=30): + """ + Shrink special tokens (silent_token and stop_mel_token) in codes + codes: [B, T] + """ + code_lens = [] + codes_list = [] + device = codes.device + dtype = codes.dtype + isfix = False + for i in range(0, codes.shape[0]): + code = codes[i] + if not torch.any(code == self.stop_mel_token).item(): + len_ = code.size(0) + else: + stop_mel_idx = (code == self.stop_mel_token).nonzero(as_tuple=False) + len_ = stop_mel_idx[0].item() if len(stop_mel_idx) > 0 else code.size(0) + + count = torch.sum(code == silent_token).item() + if count > max_consecutive: + # code = code.cpu().tolist() + ncode_idx = [] + n = 0 + for k in range(len_): + assert code[ + k] != self.stop_mel_token, f"stop_mel_token {self.stop_mel_token} should be shrinked here" + if code[k] != silent_token: + ncode_idx.append(k) + n = 0 + elif code[k] == silent_token and n < 10: + ncode_idx.append(k) + n += 1 + # if (k == 0 and code[k] == 52) or (code[k] == 52 and code[k-1] == 52): + # n += 1 + # new code + len_ = len(ncode_idx) + codes_list.append(code[ncode_idx]) + isfix = True + else: + # shrink to len_ + codes_list.append(code[:len_]) + code_lens.append(len_) + if isfix: + if len(codes_list) > 1: + codes = pad_sequence(codes_list, batch_first=True, padding_value=self.stop_mel_token) + else: + codes = codes_list[0].unsqueeze(0) + else: + # unchanged + pass + # clip codes to max length + max_len = max(code_lens) + if max_len < codes.shape[1]: + codes = codes[:, :max_len] + code_lens = torch.tensor(code_lens, dtype=torch.long, device=device) + return codes, code_lens + + def insert_interval_silence(self, wavs, sampling_rate=22050, interval_silence=200): + """ + Insert silences between sentences. + wavs: List[torch.tensor] + """ + + if not wavs or interval_silence <= 0: + return wavs + + # get channel_size + channel_size = wavs[0].size(0) + # get silence tensor + sil_dur = int(sampling_rate * interval_silence / 1000.0) + sil_tensor = torch.zeros(channel_size, sil_dur) + + wavs_list = [] + for i, wav in enumerate(wavs): + wavs_list.append(wav) + if i < len(wavs) - 1: + wavs_list.append(sil_tensor) + + return wavs_list + + def _set_gr_progress(self, value, desc): + if self.gr_progress is not None: + self.gr_progress(value, desc=desc) + + # 原始推理模式 + def infer(self, spk_audio_prompt, text, output_path, + emo_audio_prompt=None, emo_alpha=1.0, + emo_vector=None, + use_emo_text=False, emo_text=None, use_random=False, interval_silence=200, + verbose=False, max_text_tokens_per_sentence=120, **generation_kwargs): + print(">> start inference...") + self._set_gr_progress(0, "start inference...") + if verbose: + print(f"origin text:{text}, spk_audio_prompt:{spk_audio_prompt}," + f" emo_audio_prompt:{emo_audio_prompt}, emo_alpha:{emo_alpha}, " + f"emo_vector:{emo_vector}, use_emo_text:{use_emo_text}, " + f"emo_text:{emo_text}") + start_time = time.perf_counter() + + if use_emo_text: + emo_audio_prompt = None + emo_alpha = 1.0 + # assert emo_audio_prompt is None + # assert emo_alpha == 1.0 + if emo_text is None: + emo_text = text + emo_dict, content = self.qwen_emo.inference(emo_text) + print(emo_dict) + emo_vector = list(emo_dict.values()) + + if emo_vector is not None: + emo_audio_prompt = None + emo_alpha = 1.0 + # assert emo_audio_prompt is None + # assert emo_alpha == 1.0 + + if emo_audio_prompt is None: + emo_audio_prompt = spk_audio_prompt + emo_alpha = 1.0 + # assert emo_alpha == 1.0 + + # 如果参考音频改变了,才需要重新生成, 提升速度 + if self.cache_spk_cond is None or self.cache_spk_audio_prompt != spk_audio_prompt: + audio, sr = librosa.load(spk_audio_prompt) + audio = torch.tensor(audio).unsqueeze(0) + audio_22k = torchaudio.transforms.Resample(sr, 22050)(audio) + audio_16k = torchaudio.transforms.Resample(sr, 16000)(audio) + + inputs = self.extract_features(audio_16k, sampling_rate=16000, return_tensors="pt") + input_features = inputs["input_features"] + attention_mask = inputs["attention_mask"] + input_features = input_features.to(self.device) + attention_mask = attention_mask.to(self.device) + spk_cond_emb = self.get_emb(input_features, attention_mask) + + _, S_ref = self.semantic_codec.quantize(spk_cond_emb) + ref_mel = self.mel_fn(audio_22k.to(spk_cond_emb.device).float()) + ref_target_lengths = torch.LongTensor([ref_mel.size(2)]).to(ref_mel.device) + feat = torchaudio.compliance.kaldi.fbank(audio_16k.to(ref_mel.device), + num_mel_bins=80, + dither=0, + sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) # feat2另外一个滤波器能量组特征[922, 80] + style = self.campplus_model(feat.unsqueeze(0)) # 参考音频的全局style2[1,192] + + prompt_condition = self.s2mel.models['length_regulator'](S_ref, + ylens=ref_target_lengths, + n_quantizers=3, + f0=None)[0] + + self.cache_spk_cond = spk_cond_emb + self.cache_s2mel_style = style + self.cache_s2mel_prompt = prompt_condition + self.cache_spk_audio_prompt = spk_audio_prompt + self.cache_mel = ref_mel + else: + style = self.cache_s2mel_style + prompt_condition = self.cache_s2mel_prompt + spk_cond_emb = self.cache_spk_cond + ref_mel = self.cache_mel + + if emo_vector is not None: + weight_vector = torch.tensor(emo_vector).to(self.device) + if use_random: + random_index = [random.randint(0, x - 1) for x in self.emo_num] + else: + random_index = [find_most_similar_cosine(style, tmp) for tmp in self.spk_matrix] + + emo_matrix = [tmp[index].unsqueeze(0) for index, tmp in zip(random_index, self.emo_matrix)] + emo_matrix = torch.cat(emo_matrix, 0) + emovec_mat = weight_vector.unsqueeze(1) * emo_matrix + emovec_mat = torch.sum(emovec_mat, 0) + emovec_mat = emovec_mat.unsqueeze(0) + + if self.cache_emo_cond is None or self.cache_emo_audio_prompt != emo_audio_prompt: + emo_audio, _ = librosa.load(emo_audio_prompt, sr=16000) + emo_inputs = self.extract_features(emo_audio, sampling_rate=16000, return_tensors="pt") + emo_input_features = emo_inputs["input_features"] + emo_attention_mask = emo_inputs["attention_mask"] + emo_input_features = emo_input_features.to(self.device) + emo_attention_mask = emo_attention_mask.to(self.device) + emo_cond_emb = self.get_emb(emo_input_features, emo_attention_mask) + + self.cache_emo_cond = emo_cond_emb + self.cache_emo_audio_prompt = emo_audio_prompt + else: + emo_cond_emb = self.cache_emo_cond + + self._set_gr_progress(0.1, "text processing...") + text_tokens_list = self.tokenizer.tokenize(text) + sentences = self.tokenizer.split_sentences(text_tokens_list, max_text_tokens_per_sentence) + if verbose: + print("text_tokens_list:", text_tokens_list) + print("sentences count:", len(sentences)) + print("max_text_tokens_per_sentence:", max_text_tokens_per_sentence) + print(*sentences, sep="\n") + do_sample = generation_kwargs.pop("do_sample", True) + top_p = generation_kwargs.pop("top_p", 0.8) + top_k = generation_kwargs.pop("top_k", 30) + temperature = generation_kwargs.pop("temperature", 0.8) + autoregressive_batch_size = 1 + length_penalty = generation_kwargs.pop("length_penalty", 0.0) + num_beams = generation_kwargs.pop("num_beams", 3) + repetition_penalty = generation_kwargs.pop("repetition_penalty", 10.0) + max_mel_tokens = generation_kwargs.pop("max_mel_tokens", 1500) + sampling_rate = 22050 + + wavs = [] + gpt_gen_time = 0 + gpt_forward_time = 0 + s2mel_time = 0 + bigvgan_time = 0 + progress = 0 + has_warned = False + for sent in sentences: + text_tokens = self.tokenizer.convert_tokens_to_ids(sent) + text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=self.device).unsqueeze(0) + if verbose: + print(text_tokens) + print(f"text_tokens shape: {text_tokens.shape}, text_tokens type: {text_tokens.dtype}") + # debug tokenizer + text_token_syms = self.tokenizer.convert_ids_to_tokens(text_tokens[0].tolist()) + print("text_token_syms is same as sentence tokens", text_token_syms == sent) + + m_start_time = time.perf_counter() + with torch.no_grad(): + with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): + emovec = self.gpt.merge_emovec( + spk_cond_emb, + emo_cond_emb, + torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device), + torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device), + alpha=emo_alpha + ) + + if emo_vector is not None: + emovec = emovec_mat + (1 - torch.sum(weight_vector)) * emovec + # emovec = emovec_mat + + codes, speech_conditioning_latent = self.gpt.inference_speech( + spk_cond_emb, + text_tokens, + emo_cond_emb, + cond_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device), + emo_cond_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device), + emo_vec=emovec, + do_sample=True, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=autoregressive_batch_size, + length_penalty=length_penalty, + num_beams=num_beams, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **generation_kwargs + ) + + gpt_gen_time += time.perf_counter() - m_start_time + if not has_warned and (codes[:, -1] != self.stop_mel_token).any(): + warnings.warn( + f"WARN: generation stopped due to exceeding `max_mel_tokens` ({max_mel_tokens}). " + f"Input text tokens: {text_tokens.shape[1]}. " + f"Consider reducing `max_text_tokens_per_sentence`({max_text_tokens_per_sentence}) or increasing `max_mel_tokens`.", + category=RuntimeWarning + ) + has_warned = True + + code_lens = torch.tensor([codes.shape[-1]], device=codes.device, dtype=codes.dtype) + # if verbose: + # print(codes, type(codes)) + # print(f"codes shape: {codes.shape}, codes type: {codes.dtype}") + # print(f"code len: {code_lens}") + + code_lens = [] + for code in codes: + if self.stop_mel_token not in code: + code_lens.append(len(code)) + code_len = len(code) + else: + len_ = (code == self.stop_mel_token).nonzero(as_tuple=False)[0] + 1 + code_len = len_ - 1 + code_lens.append(code_len) + codes = codes[:, :code_len] + code_lens = torch.LongTensor(code_lens) + code_lens = code_lens.to(self.device) + if verbose: + print(codes, type(codes)) + print(f"fix codes shape: {codes.shape}, codes type: {codes.dtype}") + print(f"code len: {code_lens}") + + m_start_time = time.perf_counter() + use_speed = torch.zeros(spk_cond_emb.size(0)).to(spk_cond_emb.device).long() + with torch.amp.autocast(text_tokens.device.type, enabled=self.dtype is not None, dtype=self.dtype): + latent = self.gpt( + speech_conditioning_latent, + text_tokens, + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), + codes, + torch.tensor([codes.shape[-1]], device=text_tokens.device), + emo_cond_emb, + cond_mel_lengths=torch.tensor([spk_cond_emb.shape[-1]], device=text_tokens.device), + emo_cond_mel_lengths=torch.tensor([emo_cond_emb.shape[-1]], device=text_tokens.device), + emo_vec=emovec, + use_speed=use_speed, + ) + gpt_forward_time += time.perf_counter() - m_start_time + + dtype = None + with torch.amp.autocast(text_tokens.device.type, enabled=dtype is not None, dtype=dtype): + m_start_time = time.perf_counter() + diffusion_steps = 25 + inference_cfg_rate = 0.7 + latent = self.s2mel.models['gpt_layer'](latent) + S_infer = self.semantic_codec.quantizer.vq2emb(codes.unsqueeze(1)) + S_infer = S_infer.transpose(1, 2) + S_infer = S_infer + latent + target_lengths = (code_lens * 1.72).long() + + cond = self.s2mel.models['length_regulator'](S_infer, + ylens=target_lengths, + n_quantizers=3, + f0=None)[0] + cat_condition = torch.cat([prompt_condition, cond], dim=1) + vc_target = self.s2mel.models['cfm'].inference(cat_condition, + torch.LongTensor([cat_condition.size(1)]).to( + cond.device), + ref_mel, style, None, diffusion_steps, + inference_cfg_rate=inference_cfg_rate) + vc_target = vc_target[:, :, ref_mel.size(-1):] + s2mel_time += time.perf_counter() - m_start_time + + m_start_time = time.perf_counter() + wav = self.bigvgan(vc_target.float()).squeeze().unsqueeze(0) + print(wav.shape) + bigvgan_time += time.perf_counter() - m_start_time + wav = wav.squeeze(1) + + wav = torch.clamp(32767 * wav, -32767.0, 32767.0) + if verbose: + print(f"wav shape: {wav.shape}", "min:", wav.min(), "max:", wav.max()) + # wavs.append(wav[:, :-512]) + wavs.append(wav.cpu()) # to cpu before saving + end_time = time.perf_counter() + self._set_gr_progress(0.9, "save audio...") + wavs = self.insert_interval_silence(wavs, sampling_rate=sampling_rate, interval_silence=interval_silence) + wav = torch.cat(wavs, dim=1) + wav_length = wav.shape[-1] / sampling_rate + print(f">> gpt_gen_time: {gpt_gen_time:.2f} seconds") + print(f">> gpt_forward_time: {gpt_forward_time:.2f} seconds") + print(f">> s2mel_time: {s2mel_time:.2f} seconds") + print(f">> bigvgan_time: {bigvgan_time:.2f} seconds") + print(f">> Total inference time: {end_time - start_time:.2f} seconds") + print(f">> Generated audio length: {wav_length:.2f} seconds") + print(f">> RTF: {(end_time - start_time) / wav_length:.4f}") + + # save audio + wav = wav.cpu() # to cpu + if output_path: + # 直接保存音频到指定路径中 + if os.path.isfile(output_path): + os.remove(output_path) + print(">> remove old wav file:", output_path) + if os.path.dirname(output_path) != "": + os.makedirs(os.path.dirname(output_path), exist_ok=True) + torchaudio.save(output_path, wav.type(torch.int16), sampling_rate) + print(">> wav file saved to:", output_path) + return output_path + else: + # 返回以符合Gradio的格式要求 + wav_data = wav.type(torch.int16) + wav_data = wav_data.numpy().T + return (sampling_rate, wav_data) + + +def find_most_similar_cosine(query_vector, matrix): + query_vector = query_vector.float() + matrix = matrix.float() + + similarities = F.cosine_similarity(query_vector, matrix, dim=1) + most_similar_index = torch.argmax(similarities) + return most_similar_index + +class QwenEmotion: + def __init__(self, model_dir): + self.model_dir = model_dir + self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) + self.model = AutoModelForCausalLM.from_pretrained( + self.model_dir, + torch_dtype="float16", # "auto" + device_map="auto" + ) + self.prompt = "文本情感分类" + self.convert_dict = { + "愤怒": "angry", + "高兴": "happy", + "恐惧": "fear", + "反感": "hate", + "悲伤": "sad", + "低落": "low", + "惊讶": "surprise", + "自然": "neutral", + } + self.backup_dict = {"happy": 0, "angry": 0, "sad": 0, "fear": 0, "hate": 0, "low": 0, "surprise": 0, + "neutral": 1.0} + self.max_score = 1.2 + self.min_score = 0.0 + + def convert(self, content): + content = content.replace("\n", " ") + content = content.replace(" ", "") + content = content.replace("{", "") + content = content.replace("}", "") + content = content.replace('"', "") + parts = content.strip().split(',') + print(parts) + parts_dict = {} + desired_order = ["高兴", "愤怒", "悲伤", "恐惧", "反感", "低落", "惊讶", "自然"] + for part in parts: + key_value = part.strip().split(':') + if len(key_value) == 2: + parts_dict[key_value[0].strip()] = part + # 按照期望顺序重新排列 + ordered_parts = [parts_dict[key] for key in desired_order if key in parts_dict] + parts = ordered_parts + if len(parts) != len(self.convert_dict): + return self.backup_dict + + emotion_dict = {} + for part in parts: + key_value = part.strip().split(':') + if len(key_value) == 2: + try: + key = self.convert_dict[key_value[0].strip()] + value = float(key_value[1].strip()) + value = max(self.min_score, min(self.max_score, value)) + emotion_dict[key] = value + except Exception: + continue + + for key in self.backup_dict: + if key not in emotion_dict: + emotion_dict[key] = 0.0 + + if sum(emotion_dict.values()) <= 0: + return self.backup_dict + + return emotion_dict + + def inference(self, text_input): + start = time.time() + messages = [ + {"role": "system", "content": f"{self.prompt}"}, + {"role": "user", "content": f"{text_input}"} + ] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) + + # conduct text completion + generated_ids = self.model.generate( + **model_inputs, + max_new_tokens=32768, + pad_token_id=self.tokenizer.eos_token_id + ) + output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() + + # parsing thinking content + try: + # rindex finding 151668 () + index = len(output_ids) - output_ids[::-1].index(151668) + except ValueError: + index = 0 + + content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n") + emotion_dict = self.convert(content) + return emotion_dict, content + + +if __name__ == "__main__": + prompt_wav = "examples/voice_01.wav" + text = '欢迎大家来体验indextts2,并给予我们意见与反馈,谢谢大家。' + + tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_cuda_kernel=False) + tts.infer(spk_audio_prompt=prompt_wav, text=text, output_path="gen.wav", verbose=True) diff --git a/indextts/s2mel/dac/__init__.py b/indextts/s2mel/dac/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51205ef6ded9c6735a988b76008e0f6bdce8e215 --- /dev/null +++ b/indextts/s2mel/dac/__init__.py @@ -0,0 +1,16 @@ +__version__ = "1.0.0" + +# preserved here for legacy reasons +__model_version__ = "latest" + +import audiotools + +audiotools.ml.BaseModel.INTERN += ["dac.**"] +audiotools.ml.BaseModel.EXTERN += ["einops"] + + +from . import nn +from . import model +from . import utils +from .model import DAC +from .model import DACFile diff --git a/indextts/s2mel/dac/__main__.py b/indextts/s2mel/dac/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa8d15307997663f8143669c2bd56e0889cb021 --- /dev/null +++ b/indextts/s2mel/dac/__main__.py @@ -0,0 +1,36 @@ +import sys + +import argbind + +from dac.utils import download +from dac.utils.decode import decode +from dac.utils.encode import encode + +STAGES = ["encode", "decode", "download"] + + +def run(stage: str): + """Run stages. + + Parameters + ---------- + stage : str + Stage to run + """ + if stage not in STAGES: + raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}") + stage_fn = globals()[stage] + + if stage == "download": + stage_fn() + return + + stage_fn() + + +if __name__ == "__main__": + group = sys.argv.pop(1) + args = argbind.parse_args(group=group) + + with argbind.scope(args): + run(group) diff --git a/indextts/s2mel/dac/model/__init__.py b/indextts/s2mel/dac/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02a75b7ad6028f5c41b6a8285b0257d4c23bdfcf --- /dev/null +++ b/indextts/s2mel/dac/model/__init__.py @@ -0,0 +1,4 @@ +from .base import CodecMixin +from .base import DACFile +from .dac import DAC +from .discriminator import Discriminator diff --git a/indextts/s2mel/dac/model/base.py b/indextts/s2mel/dac/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..546b3cb7092d6bd1837ec780228d2a5b3e01fe8d --- /dev/null +++ b/indextts/s2mel/dac/model/base.py @@ -0,0 +1,294 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." + ) + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. + + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = ( + audio_signal.signal_duration if win_duration is None else win_duration + ) + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. + verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + -1, obj.channels, obj.original_length + ) + + self.padding = original_padding + return recons diff --git a/indextts/s2mel/dac/model/dac.py b/indextts/s2mel/dac/model/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..5224dbadcc25c40404326c963777ad99ff8b6440 --- /dev/null +++ b/indextts/s2mel/dac/model/dac.py @@ -0,0 +1,400 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from indextts.s2mel.dac.nn.layers import Snake1d +from indextts.s2mel.dac.nn.layers import WNConv1d +from indextts.s2mel.dac.nn.layers import WNConvTranspose1d +from indextts.s2mel.dac.nn.quantize import ResidualVectorQuantize +from .encodec import SConv1d, SConvTranspose1d, SLSTM + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False): + super().__init__() + conv1d_type = SConv1d# if causal else WNConv1d + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + conv1d_type(dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal, norm='weight_norm'), + Snake1d(dim), + conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False): + super().__init__() + conv1d_type = SConv1d# if causal else WNConv1d + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1, causal=causal), + ResidualUnit(dim // 2, dilation=3, causal=causal), + ResidualUnit(dim // 2, dilation=9, causal=causal), + Snake1d(dim // 2), + conv1d_type( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + causal=causal, + norm='weight_norm', + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + causal: bool = False, + lstm: int = 2, + ): + super().__init__() + conv1d_type = SConv1d# if causal else WNConv1d + # Create first convolution + self.block = [conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride, causal=causal)] + + # Add LSTM if needed + self.use_lstm = lstm + if lstm: + self.block += [SLSTM(d_model, lstm)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + def reset_cache(self): + # recursively find all submodules named SConv1d in self.block and use their reset_cache method + def reset_cache(m): + if isinstance(m, SConv1d) or isinstance(m, SLSTM): + m.reset_cache() + return + for child in m.children(): + reset_cache(child) + + reset_cache(self.block) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False): + super().__init__() + conv1d_type = SConvTranspose1d #if causal else WNConvTranspose1d + self.block = nn.Sequential( + Snake1d(input_dim), + conv1d_type( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + causal=causal, + norm='weight_norm' + ), + ResidualUnit(output_dim, dilation=1, causal=causal), + ResidualUnit(output_dim, dilation=3, causal=causal), + ResidualUnit(output_dim, dilation=9, causal=causal), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + causal: bool = False, + lstm: int = 2, + ): + super().__init__() + conv1d_type = SConv1d# if causal else WNConv1d + # Add first conv layer + layers = [conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')] + + if lstm: + layers += [SLSTM(channels, num_layers=lstm)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + lstm: int = 2, + causal: bool = False, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + lstm=lstm, + causal=causal, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer( + z, n_quantizers + ) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode( + audio_data, n_quantizers + ) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." + setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/indextts/s2mel/dac/model/discriminator.py b/indextts/s2mel/dac/model/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9018d70b29786c7392146dc113a4c371ba9ef9 --- /dev/null +++ b/indextts/s2mel/dac/model/discriminator.py @@ -0,0 +1,228 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import ml +from audiotools import STFTParams +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class Discriminator(nn.Module): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. + periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = Discriminator() + x = torch.zeros(1, 1, 44100) + results = disc(x) + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print() diff --git a/indextts/s2mel/dac/model/encodec.py b/indextts/s2mel/dac/model/encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..8bbc42e15c2fa8fa558c6d45ff8c6f871667de04 --- /dev/null +++ b/indextts/s2mel/dac/model/encodec.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +import typing as tp + +import einops + + +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. + """ + def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, 'b ... t -> b t ...') + x = super().forward(x) + x = einops.rearrange(x, 'b t ... -> b ... t') + return + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`. + """ + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect', **kwargs): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + self.cache_enabled = False + + def reset_cache(self): + """Reset the cache when starting a new stream.""" + self.cache = None + self.cache_enabled = True + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + + if self.causal: + # Left padding for causal + if self.cache_enabled and self.cache is not None: + # Concatenate the cache (previous inputs) with the new input for streaming + x = torch.cat([self.cache, x], dim=2) + else: + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + + # Store the most recent input frames for future cache use + if self.cache_enabled: + if self.cache is None: + # Initialize cache with zeros (at the start of streaming) + self.cache = torch.zeros(B, C, kernel_size - 1, device=x.device) + # Update the cache by storing the latest input frames + if kernel_size > 1: + self.cache = x[:, :, -kernel_size + 1:].detach() # Only store the necessary frames + + return self.conv(x) + + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + self.hidden = None + self.cache_enabled = False + + def forward(self, x): + x = x.permute(2, 0, 1) + if self.training or not self.cache_enabled: + y, _ = self.lstm(x) + else: + y, self.hidden = self.lstm(x, self.hidden) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y + + def reset_cache(self): + self.hidden = None + self.cache_enabled = True \ No newline at end of file diff --git a/indextts/s2mel/dac/nn/__init__.py b/indextts/s2mel/dac/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7 --- /dev/null +++ b/indextts/s2mel/dac/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . import quantize diff --git a/indextts/s2mel/dac/nn/layers.py b/indextts/s2mel/dac/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/indextts/s2mel/dac/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/indextts/s2mel/dac/nn/loss.py b/indextts/s2mel/dac/nn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b --- /dev/null +++ b/indextts/s2mel/dac/nn/loss.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/indextts/s2mel/dac/nn/quantize.py b/indextts/s2mel/dac/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..5d11ed2896b69fd90d058e6751f23652db8b15e3 --- /dev/null +++ b/indextts/s2mel/dac/nn/quantize.py @@ -0,0 +1,339 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from indextts.s2mel.dac.nn.layers import WNConv1d + +class VectorQuantizeLegacy(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + removed in-out projection + """ + + def __init__(self, input_dim: int, codebook_size: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook = nn.Embedding(codebook_size, input_dim) + + def forward(self, z, z_mask=None): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + z_e = z + z_q, indices = self.decode_latents(z) + + if z_mask is not None: + commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum() + codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum() + else: + commitment_loss = F.mse_loss(z_e, z_q.detach()) + codebook_loss = F.mse_loss(z_q, z_e.detach()) + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + return z_q, indices, z_e, commitment_loss, codebook_loss + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z, z_mask=None): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + if z_mask is not None: + commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum() + codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum() + else: + commitment_loss = F.mse_loss(z_e, z_q.detach()) + codebook_loss = F.mse_loss(z_q, z_e.detach()) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/indextts/s2mel/dac/utils/__init__.py b/indextts/s2mel/dac/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5245758d0517df36847a376251c19927e0380077 --- /dev/null +++ b/indextts/s2mel/dac/utils/__init__.py @@ -0,0 +1,123 @@ +from pathlib import Path + +import argbind +from audiotools import ml + +import indextts.s2mel.dac as dac + +DAC = dac.model.DAC +Accelerator = ml.Accelerator + +__MODEL_LATEST_TAGS__ = { + ("44khz", "8kbps"): "0.0.1", + ("24khz", "8kbps"): "0.0.4", + ("16khz", "8kbps"): "0.0.5", + ("44khz", "16kbps"): "1.0.0", +} + +__MODEL_URLS__ = { + ( + "44khz", + "0.0.1", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", + ( + "24khz", + "0.0.4", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", + ( + "16khz", + "0.0.5", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", + ( + "44khz", + "1.0.0", + "16kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", +} + + +@argbind.bind(group="download", positional=True, without_prefix=True) +def download( + model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" +): + """ + Function that downloads the weights file from URL if a local cache is not found. + + Parameters + ---------- + model_type : str + The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + Only 44khz model supports 16kbps. + tag : str + The tag of the model to download. Defaults to "latest". + + Returns + ------- + Path + Directory path required to load model via audiotools. + """ + model_type = model_type.lower() + tag = tag.lower() + + assert model_type in [ + "44khz", + "24khz", + "16khz", + ], "model_type must be one of '44khz', '24khz', or '16khz'" + + assert model_bitrate in [ + "8kbps", + "16kbps", + ], "model_bitrate must be one of '8kbps', or '16kbps'" + + if tag == "latest": + tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] + + download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) + + if download_link is None: + raise ValueError( + f"Could not find model with tag {tag} and model type {model_type}" + ) + + local_path = ( + Path.home() + / ".cache" + / "descript" + / "dac" + / f"weights_{model_type}_{model_bitrate}_{tag}.pth" + ) + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + + # Download the model + import requests + + response = requests.get(download_link) + + if response.status_code != 200: + raise ValueError( + f"Could not download model. Received response code {response.status_code}" + ) + local_path.write_bytes(response.content) + + return local_path + + +def load_model( + model_type: str = "44khz", + model_bitrate: str = "8kbps", + tag: str = "latest", + load_path: str = None, +): + if not load_path: + load_path = download( + model_type=model_type, model_bitrate=model_bitrate, tag=tag + ) + generator = DAC.load(load_path) + return generator diff --git a/indextts/s2mel/dac/utils/decode.py b/indextts/s2mel/dac/utils/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..08d44e8453ec4fa3433c2a9952d1a4da15315939 --- /dev/null +++ b/indextts/s2mel/dac/utils/decode.py @@ -0,0 +1,95 @@ +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from tqdm import tqdm + +from dac import DACFile +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="decode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def decode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + device: str = "cuda", + model_type: str = "44khz", + verbose: bool = False, +): + """Decode audio from codes. + + Parameters + ---------- + input : str + Path to input directory or file + output : str, optional + Path to output directory, by default "". + If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + device : str, optional + Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + + # Find all .dac files in input directory + _input = Path(input) + input_files = list(_input.glob("**/*.dac")) + + # If input is a .dac file, add it to the list + if _input.suffix == ".dac": + input_files.append(_input) + + # Create output directory + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(input_files)), desc=f"Decoding files"): + # Load file + artifact = DACFile.load(input_files[i]) + + # Reconstruct audio from codes + recons = generator.decompress(artifact, verbose=verbose) + + # Compute output path + relative_path = input_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = input_files[i] + output_name = relative_path.with_suffix(".wav").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to file + recons.write(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + decode() diff --git a/indextts/s2mel/dac/utils/encode.py b/indextts/s2mel/dac/utils/encode.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3f6f44b3c210f485da1b1726b85494ff5e7804 --- /dev/null +++ b/indextts/s2mel/dac/utils/encode.py @@ -0,0 +1,94 @@ +import math +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.core import util +from tqdm import tqdm + +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="encode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def encode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + n_quantizers: int = None, + device: str = "cuda", + model_type: str = "44khz", + win_duration: float = 5.0, + verbose: bool = False, +): + """Encode audio files in input path to .dac format. + + Parameters + ---------- + input : str + Path to input audio file or directory + output : str, optional + Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + n_quantizers : int, optional + Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. + device : str, optional + Device to use, by default "cuda" + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + kwargs = {"n_quantizers": n_quantizers} + + # Find all audio files in input path + input = Path(input) + audio_files = util.find_audio(input) + + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(audio_files)), desc="Encoding files"): + # Load file + signal = AudioSignal(audio_files[i]) + + # Encode audio to .dac format + artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) + + # Compute output path + relative_path = audio_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = audio_files[i] + output_name = relative_path.with_suffix(".dac").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + artifact.save(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + encode() diff --git a/indextts/s2mel/hf_utils.py b/indextts/s2mel/hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c95131931093d5a243a51c6b1b42d65496e2307 --- /dev/null +++ b/indextts/s2mel/hf_utils.py @@ -0,0 +1,12 @@ +import os +from huggingface_hub import hf_hub_download + + +def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"): + os.makedirs("./checkpoints", exist_ok=True) + model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints") + if config_filename is None: + return model_path + config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints") + + return model_path, config_path \ No newline at end of file diff --git a/indextts/s2mel/modules/alias_free_torch/__init__.py b/indextts/s2mel/modules/alias_free_torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bccdb97a3706bcb7149f48e04178cf00a5e877 --- /dev/null +++ b/indextts/s2mel/modules/alias_free_torch/__init__.py @@ -0,0 +1,5 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +from .filter import * +from .resample import * +from .act import * diff --git a/indextts/s2mel/modules/alias_free_torch/act.py b/indextts/s2mel/modules/alias_free_torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..779d58d5f1e889f8b639dd019a0ce951e69e4cfb --- /dev/null +++ b/indextts/s2mel/modules/alias_free_torch/act.py @@ -0,0 +1,29 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/indextts/s2mel/modules/alias_free_torch/filter.py b/indextts/s2mel/modules/alias_free_torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..ece8e02fce0e65e13522e990a80d1bfeeffd46ba --- /dev/null +++ b/indextts/s2mel/modules/alias_free_torch/filter.py @@ -0,0 +1,96 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +def kaiser_sinc_filter1d( + cutoff, half_width, kernel_size +): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + # input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out diff --git a/indextts/s2mel/modules/alias_free_torch/resample.py b/indextts/s2mel/modules/alias_free_torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..ee993b10339141b469b67c3e11f5d73c5f4e0bca --- /dev/null +++ b/indextts/s2mel/modules/alias_free_torch/resample.py @@ -0,0 +1,57 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/indextts/s2mel/modules/audio.py b/indextts/s2mel/modules/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..eea6c6784745e2e9718d85144f15ba57879e7698 --- /dev/null +++ b/indextts/s2mel/modules/audio.py @@ -0,0 +1,82 @@ +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): +# if torch.min(y) < -1.0: +# print("min value is ", torch.min(y)) +# if torch.max(y) > 1.0: +# print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(sampling_rate) + "_" + str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/indextts/s2mel/modules/bigvgan/activations.py b/indextts/s2mel/modules/bigvgan/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..61f2808a5466b3cf4d041059700993af5527dd29 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py new file mode 100644 index 0000000000000000000000000000000000000000..01a25ffc04edbdfe17859653690e3005d441712f --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +from ..torch.resample import UpSample1d, DownSample1d + +# load fused CUDA kernel: this enables importing anti_alias_activation_cuda +from ..cuda import load + +anti_alias_activation_cuda = load.load() + + +class FusedAntiAliasActivation(torch.autograd.Function): + """ + Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs. + The hyperparameters are hard-coded in the kernel to maximize speed. + NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters. + """ + + @staticmethod + def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta): + activation_results = anti_alias_activation_cuda.forward( + inputs, up_ftr, down_ftr, alpha, beta + ) + + return activation_results + + @staticmethod + def backward(ctx, output_grads): + raise NotImplementedError + return output_grads, None, None + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + fused: bool = True, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + self.fused = fused # Whether to use fused CUDA kernel or not + + def forward(self, x): + if not self.fused: + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + else: + if self.act.__class__.__name__ == "Snake": + beta = self.act.alpha.data # Snake uses same params for alpha and beta + else: + beta = ( + self.act.beta.data + ) # Snakebeta uses different params for alpha and beta + alpha = self.act.alpha.data + if ( + not self.act.alpha_logscale + ): # Exp baked into cuda kernel, cancel it out with a log + alpha = torch.log(alpha) + beta = torch.log(beta) + + x = FusedAntiAliasActivation.apply( + x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta + ) + return x diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5651f77143bd678169eb11564a7cf7a7969a59e --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp @@ -0,0 +1,23 @@ +/* coding=utf-8 + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #include + +extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)"); +} \ No newline at end of file diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..8c442334869fe72d639ec203fa4fac07f96a0ee1 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu @@ -0,0 +1,246 @@ +/* coding=utf-8 + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "type_shim.h" +#include +#include +#include +#include +#include + +namespace +{ + // Hard-coded hyperparameters + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4; + constexpr int BUFFER_SIZE = 32; + constexpr int FILTER_SIZE = 12; + constexpr int HALF_FILTER_SIZE = 6; + constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl + constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl + constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl + + template + __global__ void anti_alias_activation_forward( + output_t *dst, + const input_t *src, + const input_t *up_ftr, + const input_t *down_ftr, + const input_t *alpha, + const input_t *beta, + int batch_size, + int channels, + int seq_len) + { + // Up and downsample filters + input_t up_filter[FILTER_SIZE]; + input_t down_filter[FILTER_SIZE]; + + // Load data from global memory including extra indices reserved for replication paddings + input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0}; + input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0}; + + // Output stores downsampled output before writing to dst + output_t output[BUFFER_SIZE]; + + // blockDim/threadIdx = (128, 1, 1) + // gridDim/blockIdx = (seq_blocks, channels, batches) + int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + int local_offset = threadIdx.x * BUFFER_SIZE; + int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset; + + // intermediate have double the seq_len + int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2; + int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset; + + // Get values needed for replication padding before moving pointer + const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); + input_t seq_left_most_value = right_most_pntr[0]; + input_t seq_right_most_value = right_most_pntr[seq_len - 1]; + + // Move src and dst pointers + src += block_offset + local_offset; + dst += block_offset + local_offset; + + // Alpha and beta values for snake activatons. Applies exp by default + alpha = alpha + blockIdx.y; + input_t alpha_val = expf(alpha[0]); + beta = beta + blockIdx.y; + input_t beta_val = expf(beta[0]); + + #pragma unroll + for (int it = 0; it < FILTER_SIZE; it += 1) + { + up_filter[it] = up_ftr[it]; + down_filter[it] = down_ftr[it]; + } + + // Apply replication padding for upsampling, matching torch impl + #pragma unroll + for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1) + { + int element_index = seq_offset + it; // index for element + if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value; + } + if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value; + } + if ((element_index >= 0) && (element_index < seq_len)) + { + elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it]; + } + } + + // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later + #pragma unroll + for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1) + { + input_t acc = 0.0; + int element_index = intermediate_seq_offset + it; // index for intermediate + #pragma unroll + for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) + { + if ((element_index + f_idx) >= 0) + { + acc += up_filter[f_idx] * elements[it + f_idx]; + } + } + intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc; + } + + // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later + double no_div_by_zero = 0.000000001; + #pragma unroll + for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1) + { + intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val); + } + + // Apply replication padding before downsampling conv from intermediates + #pragma unroll + for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1) + { + intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT]; + } + #pragma unroll + for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1) + { + intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1]; + } + + // Apply downsample strided convolution (assuming stride=2) from intermediates + #pragma unroll + for (int it = 0; it < BUFFER_SIZE; it += 1) + { + input_t acc = 0.0; + #pragma unroll + for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) + { + // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation + acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT]; + } + output[it] = acc; + } + + // Write output to dst + #pragma unroll + for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG) + { + int element_index = seq_offset + it; + if (element_index < seq_len) + { + dst[it] = output[it]; + } + } + + } + + template + void dispatch_anti_alias_activation_forward( + output_t *dst, + const input_t *src, + const input_t *up_ftr, + const input_t *down_ftr, + const input_t *alpha, + const input_t *beta, + int batch_size, + int channels, + int seq_len) + { + if (seq_len == 0) + { + return; + } + else + { + // Use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + constexpr int seq_len_per_block = 4096; + int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block; + dim3 blocks(blocks_per_seq_len, channels, batch_size); + dim3 threads(threads_per_block, 1, 1); + + anti_alias_activation_forward + <<>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len); + } + } +} + +extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta) +{ + // Input is a 3d tensor with dimensions [batches, channels, seq_len] + const int batches = input.size(0); + const int channels = input.size(1); + const int seq_len = input.size(2); + + // Output + auto act_options = input.options().requires_grad(false); + + torch::Tensor anti_alias_activation_results = + torch::empty({batches, channels, seq_len}, act_options); + + void *input_ptr = static_cast(input.data_ptr()); + void *up_filter_ptr = static_cast(up_filter.data_ptr()); + void *down_filter_ptr = static_cast(down_filter.data_ptr()); + void *alpha_ptr = static_cast(alpha.data_ptr()); + void *beta_ptr = static_cast(beta.data_ptr()); + void *anti_alias_activation_results_ptr = static_cast(anti_alias_activation_results.data_ptr()); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch anti alias activation_forward", + dispatch_anti_alias_activation_forward( + reinterpret_cast(anti_alias_activation_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(up_filter_ptr), + reinterpret_cast(down_filter_ptr), + reinterpret_cast(alpha_ptr), + reinterpret_cast(beta_ptr), + batches, + channels, + seq_len);); + return anti_alias_activation_results; +} \ No newline at end of file diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h new file mode 100644 index 0000000000000000000000000000000000000000..25818b2edf4cb0dc9130e62c7c4de8d16a01baa5 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h @@ -0,0 +1,29 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5d01de398249e75e9e2298958764acb436edba --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +import os +import pathlib +import subprocess + +from torch.utils import cpp_extension + +""" +Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels. +Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below +""" +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def load(): + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,code=sm_80") + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / "build" + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=[ + "-O3", + ], + extra_cuda_cflags=[ + "-O3", + "-gencode", + "arch=compute_70,code=sm_70", + "--use_fast_math", + ] + + extra_cuda_flags + + cc_flag, + verbose=True, + ) + + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + + sources = [ + srcpath / "anti_alias_activation.cpp", + srcpath / "anti_alias_activation_cuda.cu", + ] + anti_alias_activation_cuda = _cpp_extention_load_helper( + "anti_alias_activation_cuda", sources, extra_cuda_flags + ) + + return anti_alias_activation_cuda + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h new file mode 100644 index 0000000000000000000000000000000000000000..5db7e8a397e982d4d30d16ab6060814b98b7ab83 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h @@ -0,0 +1,92 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "compat.h" + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch (TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch (TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py b/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f756ed83f87f9839e457b240f60469bc187707d --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .filter import * +from .resample import * +from .act import * diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py b/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..a6693aac602d7b331d6149522685dd512a26d277 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py @@ -0,0 +1,30 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py b/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa35b0d5ddf8d6cb04cd9d47364ca033cebcd32 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py @@ -0,0 +1,101 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d( + cutoff, half_width, kernel_size +): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + """ + Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal. + """ + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ): + """ + kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible. + """ + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + # Input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out diff --git a/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py b/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..a35380f5a2b0767069d8e3a64e01e090299ee2ab --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py @@ -0,0 +1,58 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/indextts/s2mel/modules/bigvgan/bigvgan.py b/indextts/s2mel/modules/bigvgan/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..65f0cc4fefdf9e038beed968325da324e67fb565 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/bigvgan.py @@ -0,0 +1,492 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import os +import json +from pathlib import Path +from typing import Optional, Union, Dict + +import torch +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from . import activations +from .utils import init_weights, get_padding +from .alias_free_activation.torch.act import Activation1d as TorchActivation1d +from .env import AttrDict + +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download + + +def load_hparams_from_json(path) -> AttrDict: + with open(path) as f: + data = f.read() + return AttrDict(json.loads(data)) + + +class AMPBlock1(torch.nn.Module): + """ + AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer. + AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1 + + Args: + h (AttrDict): Hyperparameters. + channels (int): Number of convolution channels. + kernel_size (int): Size of the convolution kernel. Default is 3. + dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5). + activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None. + """ + + def __init__( + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, + ): + super().__init__() + + self.h = h + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), + ) + ) + for d in dilation + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ) + for _ in range(len(dilation)) + ] + ) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len( + self.convs2 + ) # Total number of conv layers + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + # Activation functions + if activation == "snake": + self.activations = nn.ModuleList( + [ + Activation1d( + activation=activations.Snake( + channels, alpha_logscale=h.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + elif activation == "snakebeta": + self.activations = nn.ModuleList( + [ + Activation1d( + activation=activations.SnakeBeta( + channels, alpha_logscale=h.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + """ + AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer. + Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1 + + Args: + h (AttrDict): Hyperparameters. + channels (int): Number of convolution channels. + kernel_size (int): Size of the convolution kernel. Default is 3. + dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5). + activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None. + """ + + def __init__( + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, + ): + super().__init__() + + self.h = h + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), + ) + ) + for d in dilation + ] + ) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # Total number of conv layers + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + # Activation functions + if activation == "snake": + self.activations = nn.ModuleList( + [ + Activation1d( + activation=activations.Snake( + channels, alpha_logscale=h.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + elif activation == "snakebeta": + self.activations = nn.ModuleList( + [ + Activation1d( + activation=activations.SnakeBeta( + channels, alpha_logscale=h.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN( + torch.nn.Module, + PyTorchModelHubMixin, + library_name="bigvgan", + repo_url="https://github.com/NVIDIA/BigVGAN", + docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md", + pipeline_tag="audio-to-audio", + license="mit", + tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"], +): + """ + BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks). + New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks. + + Args: + h (AttrDict): Hyperparameters. + use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels. + + Note: + - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported. + - Ensure that the activation function is correctly specified in the hyperparameters (h.activation). + """ + + def __init__(self, h: AttrDict, use_cuda_kernel: bool = False): + super().__init__() + self.h = h + self.h["use_cuda_kernel"] = use_cuda_kernel + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + # Pre-conv + self.conv_pre = weight_norm( + Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + + # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + if h.resblock == "1": + resblock_class = AMPBlock1 + elif h.resblock == "2": + resblock_class = AMPBlock2 + else: + raise ValueError( + f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}" + ) + + # Transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList( + [ + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2 ** i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ] + ) + ) + + # Residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append( + resblock_class(h, ch, k, d, activation=h.activation) + ) + + # Post-conv + activation_post = ( + activations.Snake(ch, alpha_logscale=h.snake_logscale) + if h.activation == "snake" + else ( + activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) + if h.activation == "snakebeta" + else None + ) + ) + if activation_post is None: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.activation_post = Activation1d(activation=activation_post) + + # Whether to use bias for the final conv_post. Default to True for backward compatibility + self.use_bias_at_final = h.get("use_bias_at_final", True) + self.conv_post = weight_norm( + Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final) + ) + + # Weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + # Final tanh activation. Defaults to True for backward compatibility + self.use_tanh_at_final = h.get("use_tanh_at_final", True) + + def forward(self, x): + # Pre-conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # Upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # Post-conv + x = self.activation_post(x) + x = self.conv_post(x) + # Final tanh activation + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1] + + return x + + def remove_weight_norm(self): + try: + print("Removing weight norm...") + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + except ValueError: + print("[INFO] Model already removed weight norm. Skipping!") + pass + + # Additional methods for huggingface_hub support + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights and config.json from a Pytorch model to a local directory.""" + + model_path = save_directory / "bigvgan_generator.pt" + torch.save({"generator": self.state_dict()}, model_path) + + config_path = save_directory / "config.json" + with open(config_path, "w") as config_file: + json.dump(self.h, config_file, indent=4) + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: str, + cache_dir: str, + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", # Additional argument + strict: bool = False, # Additional argument + use_cuda_kernel: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + + # Download and load hyperparameters (h) used by BigVGAN + if os.path.isdir(model_id): + print("Loading config.json from local directory") + config_file = os.path.join(model_id, "config.json") + else: + config_file = hf_hub_download( + repo_id=model_id, + filename="config.json", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + h = load_hparams_from_json(config_file) + + # instantiate BigVGAN using h + if use_cuda_kernel: + print( + f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!" + ) + print( + f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!" + ) + print( + f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis" + ) + model = cls(h, use_cuda_kernel=use_cuda_kernel) + + # Download and load pretrained generator weight + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, "bigvgan_generator.pt") + else: + print(f"Loading weights from {model_id}") + model_file = hf_hub_download( + repo_id=model_id, + filename="bigvgan_generator.pt", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + + checkpoint_dict = torch.load(model_file, map_location=map_location) + + try: + model.load_state_dict(checkpoint_dict["generator"]) + except RuntimeError: + print( + f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!" + ) + model.remove_weight_norm() + model.load_state_dict(checkpoint_dict["generator"]) + + return model \ No newline at end of file diff --git a/indextts/s2mel/modules/bigvgan/config.json b/indextts/s2mel/modules/bigvgan/config.json new file mode 100644 index 0000000000000000000000000000000000000000..635bd8975629bd6d4b51c409986944a281cfe7be --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/config.json @@ -0,0 +1,63 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "normalize_volume": true, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/indextts/s2mel/modules/bigvgan/env.py b/indextts/s2mel/modules/bigvgan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..b8be238d4db710c8c9a338d336baea0138f18d1f --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/env.py @@ -0,0 +1,18 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) \ No newline at end of file diff --git a/indextts/s2mel/modules/bigvgan/meldataset.py b/indextts/s2mel/modules/bigvgan/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5e89d7384755e725c946aff3884834e15e295a16 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/meldataset.py @@ -0,0 +1,354 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import math +import os +import random +import torch +import torch.utils.data +import numpy as np +from librosa.util import normalize +from scipy.io.wavfile import read +from librosa.filters import mel as librosa_mel_fn +import pathlib +from tqdm import tqdm + +MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) + + +def load_wav(full_path, sr_target): + sampling_rate, data = read(full_path) + if sampling_rate != sr_target: + raise RuntimeError( + f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz" + ) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +def spectral_de_normalize_torch(magnitudes): + return dynamic_range_decompression_torch(magnitudes) + + +mel_basis_cache = {} +hann_window_cache = {} + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sampling_rate: int, + hop_size: int, + win_size: int, + fmin: int, + fmax: int = None, + center: bool = False, +) -> torch.Tensor: + """ + Calculate the mel spectrogram of an input signal. + This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). + + Args: + y (torch.Tensor): Input signal. + n_fft (int): FFT size. + num_mels (int): Number of mel bins. + sampling_rate (int): Sampling rate of the input signal. + hop_size (int): Hop size for STFT. + win_size (int): Window size for STFT. + fmin (int): Minimum frequency for mel filterbank. + fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn + center (bool): Whether to pad the input to center the frames. Default is False. + + Returns: + torch.Tensor: Mel spectrogram. + """ + if torch.min(y) < -1.0: + print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") + if torch.max(y) > 1.0: + print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") + + device = y.device + key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) + hann_window_cache[key] = torch.hann_window(win_size).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad( + y.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = spectral_normalize_torch(mel_spec) + + return mel_spec + + +def get_mel_spectrogram(wav, h): + """ + Generate mel spectrogram from a waveform using given hyperparameters. + + Args: + wav (torch.Tensor): Input waveform. + h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax. + + Returns: + torch.Tensor: Mel spectrogram. + """ + return mel_spectrogram( + wav, + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax, + ) + + +def get_dataset_filelist(a): + training_files = [] + validation_files = [] + list_unseen_validation_files = [] + + with open(a.input_training_file, "r", encoding="utf-8") as fi: + training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 + ] + print(f"first training file: {training_files[0]}") + + with open(a.input_validation_file, "r", encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 + ] + print(f"first validation file: {validation_files[0]}") + + for i in range(len(a.list_input_unseen_validation_file)): + with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi: + unseen_validation_files = [ + os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 + ] + print( + f"first unseen {i}th validation fileset: {unseen_validation_files[0]}" + ) + list_unseen_validation_files.append(unseen_validation_files) + + return training_files, validation_files, list_unseen_validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files, + hparams, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None, + is_seen=True, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.hparams = hparams + self.is_seen = is_seen + if self.is_seen: + self.name = pathlib.Path(self.audio_files[0]).parts[0] + else: + self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/") + + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + print("[INFO] checking dataset integrity...") + for i in tqdm(range(len(self.audio_files))): + assert os.path.exists( + self.audio_files[i] + ), f"{self.audio_files[i]} not found" + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + audio, sampling_rate = load_wav(filename, self.sampling_rate) + audio = audio / MAX_WAV_VALUE + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + if sampling_rate != self.sampling_rate: + raise ValueError( + f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR" + ) + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start : audio_start + self.segment_size] + else: + audio = torch.nn.functional.pad( + audio, (0, self.segment_size - audio.size(1)), "constant" + ) + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) + else: # Validation step + # Match audio length to self.hop_size * n for evaluation + if (audio.size(1) % self.hop_size) != 0: + audio = audio[:, : -(audio.size(1) % self.hop_size)] + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) + assert ( + audio.shape[1] == mel.shape[2] * self.hop_size + ), f"audio shape {audio.shape} mel shape {mel.shape}" + + else: + mel = np.load( + os.path.join( + self.base_mels_path, + os.path.splitext(os.path.split(filename)[-1])[0] + ".npy", + ) + ) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[ + :, + mel_start + * self.hop_size : (mel_start + frames_per_seg) + * self.hop_size, + ] + else: + mel = torch.nn.functional.pad( + mel, (0, frames_per_seg - mel.size(2)), "constant" + ) + audio = torch.nn.functional.pad( + audio, (0, self.segment_size - audio.size(1)), "constant" + ) + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False, + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/indextts/s2mel/modules/bigvgan/utils.py b/indextts/s2mel/modules/bigvgan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..da98a24cf1447778305563f8e909f30b06e06b26 --- /dev/null +++ b/indextts/s2mel/modules/bigvgan/utils.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import glob +import os +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt +from .meldataset import MAX_WAV_VALUE +from scipy.io.wavfile import write + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def plot_spectrogram_clipped(spectrogram, clip_max=2.0): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow( + spectrogram, + aspect="auto", + origin="lower", + interpolation="none", + vmin=1e-6, + vmax=clip_max, + ) + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix, renamed_file=None): + # Fallback to original scanning logic first + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + + if len(cp_list) > 0: + last_checkpoint_path = sorted(cp_list)[-1] + print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'") + return last_checkpoint_path + + # If no pattern-based checkpoints are found, check for renamed file + if renamed_file: + renamed_path = os.path.join(cp_dir, renamed_file) + if os.path.isfile(renamed_path): + print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'") + return renamed_path + + return None + + +def save_audio(audio, path, sr): + # wav: torch with 1d shape + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype("int16") + write(path, sr, audio) diff --git a/indextts/s2mel/modules/campplus/DTDNN.py b/indextts/s2mel/modules/campplus/DTDNN.py new file mode 100644 index 0000000000000000000000000000000000000000..e973a8716bb9ad80ae5b4f1cbf32c8a4213e6b09 --- /dev/null +++ b/indextts/s2mel/modules/campplus/DTDNN.py @@ -0,0 +1,115 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +from collections import OrderedDict + +import torch +from torch import nn +import torch.nn.functional as F + +from indextts.s2mel.modules.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear + + +class FCM(nn.Module): + def __init__(self, + block=BasicResBlock, + num_blocks=[2, 2], + m_channels=32, + feat_dim=80): + super(FCM, self).__init__() + self.in_planes = m_channels + self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(m_channels) + + self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2) + self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2) + + self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(m_channels) + self.out_channels = m_channels * (feat_dim // 8) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + x = x.unsqueeze(1) + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = F.relu(self.bn2(self.conv2(out))) + + shape = out.shape + out = out.reshape(shape[0], shape[1]*shape[2], shape[3]) + return out + +class CAMPPlus(nn.Module): + def __init__(self, + feat_dim=80, + embedding_size=512, + growth_rate=32, + bn_size=4, + init_channels=128, + config_str='batchnorm-relu', + memory_efficient=True): + super(CAMPPlus, self).__init__() + + self.head = FCM(feat_dim=feat_dim) + channels = self.head.out_channels + + self.xvector = nn.Sequential( + OrderedDict([ + + ('tdnn', + TDNNLayer(channels, + init_channels, + 5, + stride=2, + dilation=1, + padding=-1, + config_str=config_str)), + ])) + channels = init_channels + for i, (num_layers, kernel_size, + dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))): + block = CAMDenseTDNNBlock(num_layers=num_layers, + in_channels=channels, + out_channels=growth_rate, + bn_channels=bn_size * growth_rate, + kernel_size=kernel_size, + dilation=dilation, + config_str=config_str, + memory_efficient=memory_efficient) + self.xvector.add_module('block%d' % (i + 1), block) + channels = channels + num_layers * growth_rate + self.xvector.add_module( + 'transit%d' % (i + 1), + TransitLayer(channels, + channels // 2, + bias=False, + config_str=config_str)) + channels //= 2 + + self.xvector.add_module( + 'out_nonlinear', get_nonlinear(config_str, channels)) + + self.xvector.add_module('stats', StatsPool()) + self.xvector.add_module( + 'dense', + DenseLayer(channels * 2, embedding_size, config_str='batchnorm_')) + + for m in self.modules(): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) + x = self.head(x) + x = self.xvector(x) + return x diff --git a/indextts/s2mel/modules/campplus/classifier.py b/indextts/s2mel/modules/campplus/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..5c709e7da673ac43fdc41c4d5babdde26368f6a4 --- /dev/null +++ b/indextts/s2mel/modules/campplus/classifier.py @@ -0,0 +1,70 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.campplus.layers import DenseLayer + + +class CosineClassifier(nn.Module): + def __init__( + self, + input_dim, + num_blocks=0, + inter_dim=512, + out_neurons=1000, + ): + + super().__init__() + self.blocks = nn.ModuleList() + + for index in range(num_blocks): + self.blocks.append( + DenseLayer(input_dim, inter_dim, config_str='batchnorm') + ) + input_dim = inter_dim + + self.weight = nn.Parameter( + torch.FloatTensor(out_neurons, input_dim) + ) + nn.init.xavier_uniform_(self.weight) + + def forward(self, x): + # x: [B, dim] + for layer in self.blocks: + x = layer(x) + + # normalized + x = F.linear(F.normalize(x), F.normalize(self.weight)) + return x + +class LinearClassifier(nn.Module): + def __init__( + self, + input_dim, + num_blocks=0, + inter_dim=512, + out_neurons=1000, + ): + + super().__init__() + self.blocks = nn.ModuleList() + + self.nonlinear = nn.ReLU(inplace=True) + for index in range(num_blocks): + self.blocks.append( + DenseLayer(input_dim, inter_dim, bias=True) + ) + input_dim = inter_dim + + self.linear = nn.Linear(input_dim, out_neurons, bias=True) + + def forward(self, x): + # x: [B, dim] + x = self.nonlinear(x) + for layer in self.blocks: + x = layer(x) + x = self.linear(x) + return x \ No newline at end of file diff --git a/indextts/s2mel/modules/campplus/layers.py b/indextts/s2mel/modules/campplus/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..aae83e194779c7528cd251b76c96d814f442775c --- /dev/null +++ b/indextts/s2mel/modules/campplus/layers.py @@ -0,0 +1,253 @@ +# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch import nn + + +def get_nonlinear(config_str, channels): + nonlinear = nn.Sequential() + for name in config_str.split('-'): + if name == 'relu': + nonlinear.add_module('relu', nn.ReLU(inplace=True)) + elif name == 'prelu': + nonlinear.add_module('prelu', nn.PReLU(channels)) + elif name == 'batchnorm': + nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels)) + elif name == 'batchnorm_': + nonlinear.add_module('batchnorm', + nn.BatchNorm1d(channels, affine=False)) + else: + raise ValueError('Unexpected module ({}).'.format(name)) + return nonlinear + +def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2): + mean = x.mean(dim=dim) + std = x.std(dim=dim, unbiased=unbiased) + stats = torch.cat([mean, std], dim=-1) + if keepdim: + stats = stats.unsqueeze(dim=dim) + return stats + + +class StatsPool(nn.Module): + def forward(self, x): + return statistics_pooling(x) + + +class TDNNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + bias=False, + config_str='batchnorm-relu'): + super(TDNNLayer, self).__init__() + if padding < 0: + assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format( + kernel_size) + padding = (kernel_size - 1) // 2 * dilation + self.linear = nn.Conv1d(in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + self.nonlinear = get_nonlinear(config_str, out_channels) + + def forward(self, x): + x = self.linear(x) + x = self.nonlinear(x) + return x + + +class CAMLayer(nn.Module): + def __init__(self, + bn_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + bias, + reduction=2): + super(CAMLayer, self).__init__() + self.linear_local = nn.Conv1d(bn_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1) + self.relu = nn.ReLU(inplace=True) + self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + y = self.linear_local(x) + context = x.mean(-1, keepdim=True)+self.seg_pooling(x) + context = self.relu(self.linear1(context)) + m = self.sigmoid(self.linear2(context)) + return y*m + + def seg_pooling(self, x, seg_len=100, stype='avg'): + if stype == 'avg': + seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) + elif stype == 'max': + seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) + else: + raise ValueError('Wrong segment pooling type.') + shape = seg.shape + seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1) + seg = seg[..., :x.shape[-1]] + return seg + + +class CAMDenseTDNNLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + bn_channels, + kernel_size, + stride=1, + dilation=1, + bias=False, + config_str='batchnorm-relu', + memory_efficient=False): + super(CAMDenseTDNNLayer, self).__init__() + assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format( + kernel_size) + padding = (kernel_size - 1) // 2 * dilation + self.memory_efficient = memory_efficient + self.nonlinear1 = get_nonlinear(config_str, in_channels) + self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False) + self.nonlinear2 = get_nonlinear(config_str, bn_channels) + self.cam_layer = CAMLayer(bn_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias) + + def bn_function(self, x): + return self.linear1(self.nonlinear1(x)) + + def forward(self, x): + if self.training and self.memory_efficient: + x = cp.checkpoint(self.bn_function, x) + else: + x = self.bn_function(x) + x = self.cam_layer(self.nonlinear2(x)) + return x + + +class CAMDenseTDNNBlock(nn.ModuleList): + def __init__(self, + num_layers, + in_channels, + out_channels, + bn_channels, + kernel_size, + stride=1, + dilation=1, + bias=False, + config_str='batchnorm-relu', + memory_efficient=False): + super(CAMDenseTDNNBlock, self).__init__() + for i in range(num_layers): + layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels, + out_channels=out_channels, + bn_channels=bn_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + bias=bias, + config_str=config_str, + memory_efficient=memory_efficient) + self.add_module('tdnnd%d' % (i + 1), layer) + + def forward(self, x): + for layer in self: + x = torch.cat([x, layer(x)], dim=1) + return x + + +class TransitLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + bias=True, + config_str='batchnorm-relu'): + super(TransitLayer, self).__init__() + self.nonlinear = get_nonlinear(config_str, in_channels) + self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) + + def forward(self, x): + x = self.nonlinear(x) + x = self.linear(x) + return x + + +class DenseLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + bias=False, + config_str='batchnorm-relu'): + super(DenseLayer, self).__init__() + self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias) + self.nonlinear = get_nonlinear(config_str, out_channels) + + def forward(self, x): + if len(x.shape) == 2: + x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1) + else: + x = self.linear(x) + x = self.nonlinear(x) + return x + + +class BasicResBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicResBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, + planes, + kernel_size=3, + stride=(stride, 1), + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=(stride, 1), + bias=False), + nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out \ No newline at end of file diff --git a/indextts/s2mel/modules/commons.py b/indextts/s2mel/modules/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..35fced4bc696030e9c1384b005e1fdc614fc30b5 --- /dev/null +++ b/indextts/s2mel/modules/commons.py @@ -0,0 +1,632 @@ +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from munch import Munch +import json +import argparse +from torch.nn.parallel import DistributedDataParallel as DDP + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def slice_segments_audio(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to( + dtype=torch.long + ) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def avg_with_mask(x, mask): + assert mask.dtype == torch.float, "Mask should be float" + + if mask.ndim == 2: + mask = mask.unsqueeze(1) + + if mask.shape[1] == 1: + mask = mask.expand_as(x) + + return (x * mask).sum() / mask.sum() + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + + +def log_norm(x, mean=-4, std=4, dim=2): + """ + normalized log mel -> mel -> norm -> log(norm) + """ + x = torch.log(torch.exp(x * std + mean).norm(dim=dim)) + return x + + +def load_F0_models(path): + # load F0 model + from .JDC.model import JDCNet + + F0_model = JDCNet(num_class=1, seq_len=192) + params = torch.load(path, map_location="cpu")["net"] + F0_model.load_state_dict(params) + _ = F0_model.train() + + return F0_model + + +def modify_w2v_forward(self, output_layer=15): + """ + change forward method of w2v encoder to get its intermediate layer output + :param self: + :param layer: + :return: + """ + from transformers.modeling_outputs import BaseModelOutput + + def forward( + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill( + ~attention_mask.bool().unsqueeze(-1), 0.0 + ) + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to( + dtype=hidden_states.dtype + ) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], + 1, + attention_mask.shape[-1], + attention_mask.shape[-1], + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + deepspeed_zero3_is_enabled = False + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = ( + True + if self.training and (dropout_probability < self.config.layerdrop) + else False + ) + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + relative_position_embeddings, + output_attentions, + conv_attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if i == output_layer - 1: + break + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward + + +MATPLOTLIB_FLAG = False + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + import logging + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def normalize_f0(f0_sequence): + # Remove unvoiced frames (replace with -1) + voiced_indices = np.where(f0_sequence > 0)[0] + f0_voiced = f0_sequence[voiced_indices] + + # Convert to log scale + log_f0 = np.log2(f0_voiced) + + # Calculate mean and standard deviation + mean_f0 = np.mean(log_f0) + std_f0 = np.std(log_f0) + + # Normalize the F0 sequence + normalized_f0 = (log_f0 - mean_f0) / std_f0 + + # Create the normalized F0 sequence with unvoiced frames + normalized_sequence = np.zeros_like(f0_sequence) + normalized_sequence[voiced_indices] = normalized_f0 + normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames + + return normalized_sequence + + +class MyModel(nn.Module): + def __init__(self,args, use_emovec=False, use_gpt_latent=False): + super(MyModel, self).__init__() + from indextts.s2mel.modules.flow_matching import CFM + from indextts.s2mel.modules.length_regulator import InterpolateRegulator + + length_regulator = InterpolateRegulator( + channels=args.length_regulator.channels, + sampling_ratios=args.length_regulator.sampling_ratios, + is_discrete=args.length_regulator.is_discrete, + in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None, + vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False, + codebook_size=args.length_regulator.content_codebook_size, + n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1, + quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0, + f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False, + n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512, + ) + + if use_gpt_latent: + self.models = nn.ModuleDict({ + 'cfm': CFM(args), + 'length_regulator': length_regulator, + 'gpt_layer': torch.nn.Sequential(torch.nn.Linear(1280, 256), torch.nn.Linear(256, 128), torch.nn.Linear(128, 1024)) + }) + + else: + self.models = nn.ModuleDict({ + 'cfm': CFM(args), + 'length_regulator': length_regulator + }) + + def forward(self, x, target_lengths, prompt_len, cond, y): + x = self.models['cfm'](x, target_lengths, prompt_len, cond, y) + return x + + def forward2(self, S_ori,target_lengths,F0_ori): + x = self.models['length_regulator'](S_ori, ylens=target_lengths, f0=F0_ori) + return x + + def forward_emovec(self, x): + x = self.models['emo_layer'](x) + return x + + def forward_emo_encoder(self, x): + x = self.models['emo_encoder'](x) + return x + + def forward_gpt(self,x): + x = self.models['gpt_layer'](x) + return x + + + +def build_model(args, stage="DiT"): + if stage == "DiT": + from modules.flow_matching import CFM + from modules.length_regulator import InterpolateRegulator + + length_regulator = InterpolateRegulator( + channels=args.length_regulator.channels, + sampling_ratios=args.length_regulator.sampling_ratios, + is_discrete=args.length_regulator.is_discrete, + in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None, + vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False, + codebook_size=args.length_regulator.content_codebook_size, + n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1, + quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0, + f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False, + n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512, + ) + cfm = CFM(args) + nets = Munch( + cfm=cfm, + length_regulator=length_regulator, + ) + + elif stage == 'codec': + from dac.model.dac import Encoder + from modules.quantize import ( + FAquantizer, + ) + + encoder = Encoder( + d_model=args.DAC.encoder_dim, + strides=args.DAC.encoder_rates, + d_latent=1024, + causal=args.causal, + lstm=args.lstm, + ) + + quantizer = FAquantizer( + in_dim=1024, + n_p_codebooks=1, + n_c_codebooks=args.n_c_codebooks, + n_t_codebooks=2, + n_r_codebooks=3, + codebook_size=1024, + codebook_dim=8, + quantizer_dropout=0.5, + causal=args.causal, + separate_prosody_encoder=args.separate_prosody_encoder, + timbre_norm=args.timbre_norm, + ) + + nets = Munch( + encoder=encoder, + quantizer=quantizer, + ) + + elif stage == "mel_vocos": + from modules.vocos import Vocos + decoder = Vocos(args) + nets = Munch( + decoder=decoder, + ) + + else: + raise ValueError(f"Unknown stage: {stage}") + + return nets + + +def load_checkpoint( + model, + optimizer, + path, + load_only_params=True, + ignore_modules=[], + is_distributed=False, + load_ema=False, +): + state = torch.load(path, map_location="cpu") + params = state["net"] + if load_ema and "ema" in state: + print("Loading EMA") + for key in model: + i = 0 + for param_name in params[key]: + if "input_pos" in param_name: + continue + assert params[key][param_name].shape == state["ema"][key][0][i].shape + params[key][param_name] = state["ema"][key][0][i].clone() + i += 1 + for key in model: + if key in params and key not in ignore_modules: + if not is_distributed: + # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix + for k in list(params[key].keys()): + if k.startswith("module."): + params[key][k[len("module.") :]] = params[key][k] + del params[key][k] + model_state_dict = model[key].state_dict() + # 过滤出形状匹配的键值对 + filtered_state_dict = { + k: v + for k, v in params[key].items() + if k in model_state_dict and v.shape == model_state_dict[k].shape + } + skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys()) + if skipped_keys: + print( + f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}" + ) + print("%s loaded" % key) + model[key].load_state_dict(filtered_state_dict, strict=False) + _ = [model[key].eval() for key in model] + + if not load_only_params: + epoch = state["epoch"] + 1 + iters = state["iters"] + optimizer.load_state_dict(state["optimizer"]) + optimizer.load_scheduler_state_dict(state["scheduler"]) + + else: + epoch = 0 + iters = 0 + + return model, optimizer, epoch, iters + +def load_checkpoint2( + model, + optimizer, + path, + load_only_params=True, + ignore_modules=[], + is_distributed=False, + load_ema=False, +): + state = torch.load(path, map_location="cpu") + params = state["net"] + if load_ema and "ema" in state: + print("Loading EMA") + for key in model.models: + i = 0 + for param_name in params[key]: + if "input_pos" in param_name: + continue + assert params[key][param_name].shape == state["ema"][key][0][i].shape + params[key][param_name] = state["ema"][key][0][i].clone() + i += 1 + for key in model.models: + if key in params and key not in ignore_modules: + if not is_distributed: + # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix + for k in list(params[key].keys()): + if k.startswith("module."): + params[key][k[len("module.") :]] = params[key][k] + del params[key][k] + model_state_dict = model.models[key].state_dict() + # 过滤出形状匹配的键值对 + filtered_state_dict = { + k: v + for k, v in params[key].items() + if k in model_state_dict and v.shape == model_state_dict[k].shape + } + skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys()) + if skipped_keys: + print( + f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}" + ) + print("%s loaded" % key) + model.models[key].load_state_dict(filtered_state_dict, strict=False) + model.eval() +# _ = [model[key].eval() for key in model] + + if not load_only_params: + epoch = state["epoch"] + 1 + iters = state["iters"] + optimizer.load_state_dict(state["optimizer"]) + optimizer.load_scheduler_state_dict(state["scheduler"]) + + else: + epoch = 0 + iters = 0 + + return model, optimizer, epoch, iters + +def recursive_munch(d): + if isinstance(d, dict): + return Munch((k, recursive_munch(v)) for k, v in d.items()) + elif isinstance(d, list): + return [recursive_munch(v) for v in d] + else: + return d diff --git a/indextts/s2mel/modules/diffusion_transformer.py b/indextts/s2mel/modules/diffusion_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..23d6912b2b635d6490c522cec0a7f8d119720f10 --- /dev/null +++ b/indextts/s2mel/modules/diffusion_transformer.py @@ -0,0 +1,257 @@ +import torch +from torch import nn +import math + +from indextts.s2mel.modules.gpt_fast.model import ModelArgs, Transformer +from indextts.s2mel.modules.wavenet import WN +from indextts.s2mel.modules.commons import sequence_mask + +from torch.nn.utils import weight_norm + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.max_period = 10000 + self.scale = 1000 + + half = frequency_embedding_size // 2 + freqs = torch.exp( + -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ) + self.register_buffer("freqs", freqs) + + def timestep_embedding(self, t): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + + args = self.scale * t[:, None].float() * self.freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if self.frequency_embedding_size % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t) + t_emb = self.mlp(t_freq) + return t_emb + + +class StyleEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, input_size, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size) + self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True)) + self.input_size = input_size + self.dropout_prob = dropout_prob + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + else: + labels = self.style_in(labels) + embeddings = labels + return embeddings + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class DiT(torch.nn.Module): + def __init__( + self, + args + ): + super(DiT, self).__init__() + self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False + self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False + self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False + model_args = ModelArgs( + block_size=16384,#args.DiT.block_size, + n_layer=args.DiT.depth, + n_head=args.DiT.num_heads, + dim=args.DiT.hidden_dim, + head_dim=args.DiT.hidden_dim // args.DiT.num_heads, + vocab_size=1024, + uvit_skip_connection=self.uvit_skip_connection, + time_as_token=self.time_as_token, + ) + self.transformer = Transformer(model_args) + self.in_channels = args.DiT.in_channels + self.out_channels = args.DiT.in_channels + self.num_heads = args.DiT.num_heads + + self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True)) + + self.content_type = args.DiT.content_type # 'discrete' or 'continuous' + self.content_codebook_size = args.DiT.content_codebook_size # for discrete content + self.content_dim = args.DiT.content_dim # for continuous content + self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content + self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content + + self.is_causal = args.DiT.is_causal + + self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim) + + # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True)) + # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True)) + + input_pos = torch.arange(16384) + self.register_buffer("input_pos", input_pos) + + self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet + if self.final_layer_type == 'wavenet': + self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim) + self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) + self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1) + self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim, + kernel_size=args.wavenet.kernel_size, + dilation_rate=args.wavenet.dilation_rate, + n_layers=args.wavenet.num_layers, + gin_channels=args.wavenet.hidden_dim, + p_dropout=args.wavenet.p_dropout, + causal=False) + self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim) + self.res_projection = nn.Linear(args.DiT.hidden_dim, + args.wavenet.hidden_dim) # residual connection from tranformer output to final output + self.wavenet_style_condition = args.wavenet.style_condition + assert args.DiT.style_condition == args.wavenet.style_condition + else: + self.final_mlp = nn.Sequential( + nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim), + nn.SiLU(), + nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels), + ) + self.transformer_style_condition = args.DiT.style_condition + + + self.class_dropout_prob = args.DiT.class_dropout_prob + self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim) + + self.long_skip_connection = args.DiT.long_skip_connection + self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim) + + self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 + + args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token), + args.DiT.hidden_dim) + if self.style_as_token: + self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim) + + def setup_caches(self, max_batch_size, max_seq_length): + self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False) + + def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False): + """ + x (torch.Tensor): random noise + prompt_x (torch.Tensor): reference mel + zero mel + shape: (batch_size, 80, 795+1068) + x_lens (torch.Tensor): mel frames output + shape: (batch_size, mel_timesteps) + t (torch.Tensor): radshape: + shape: (batch_size) + style (torch.Tensor): reference global style + shape: (batch_size, 192) + cond (torch.Tensor): semantic info of reference audio and altered audio + shape: (batch_size, mel_timesteps(795+1069), 512) + + """ + class_dropout = False + if self.training and torch.rand(1) < self.class_dropout_prob: + class_dropout = True + if not self.training and mask_content: + class_dropout = True + # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection + cond_in_module = self.cond_projection + + B, _, T = x.size() + + + t1 = self.t_embedder(t) # (N, D) # t1 [2, 512] + cond = cond_in_module(cond) # cond [2,1863,512]->[2,1863,512] + + x = x.transpose(1, 2) # [2,1863,80] + prompt_x = prompt_x.transpose(1, 2) # [2,1863,80] + + x_in = torch.cat([x, prompt_x, cond], dim=-1) # 80+80+512=672 [2, 1863, 672] + + if self.transformer_style_condition and not self.style_as_token: # True and True + x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1) #[2, 1863, 864] + + if class_dropout: #False + x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0 # 80维后全置为0 + + x_in = self.cond_x_merge_linear(x_in) # (N, T, D) [2, 1863, 512] + + if self.style_as_token: # False + style = self.style_in(style) + style = torch.zeros_like(style) if class_dropout else style + x_in = torch.cat([style.unsqueeze(1), x_in], dim=1) + + if self.time_as_token: # False + x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1) + + x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1) #torch.Size([1, 1, 1863])True + input_pos = self.input_pos[:x_in.size(1)] # (T,) range(0,1863) + x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None # torch.Size([1, 1, 1863, 1863] + x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded) # [2, 1863, 512] + x_res = x_res[:, 1:] if self.time_as_token else x_res + x_res = x_res[:, 1:] if self.style_as_token else x_res + + if self.long_skip_connection: #True + x_res = self.skip_linear(torch.cat([x_res, x], dim=-1)) + if self.final_layer_type == 'wavenet': + x = self.conv1(x_res) + x = x.transpose(1, 2) + t2 = self.t_embedder2(t) + x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection( + x_res) # long residual connection + x = self.final_layer(x, t1).transpose(1, 2) + x = self.conv2(x) + else: + x = self.final_mlp(x_res) + x = x.transpose(1, 2) + # x [2,80,1863] + return x diff --git a/indextts/s2mel/modules/encodec.py b/indextts/s2mel/modules/encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..9feeadd935dabd6642af6b1844494a51967292c3 --- /dev/null +++ b/indextts/s2mel/modules/encodec.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +import typing as tp + +import einops + + +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. + """ + def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, 'b ... t -> b t ...') + x = super().forward(x) + x = einops.rearrange(x, 'b t ... -> b ... t') + return + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`. + """ + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect', **kwargs): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + self.hidden = None + + def forward(self, x): + x = x.permute(2, 0, 1) + if self.training: + y, _ = self.lstm(x) + else: + y, self.hidden = self.lstm(x, self.hidden) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y \ No newline at end of file diff --git a/indextts/s2mel/modules/flow_matching.py b/indextts/s2mel/modules/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c396695c65e7eaa4024534ffec608cee8a80a970 --- /dev/null +++ b/indextts/s2mel/modules/flow_matching.py @@ -0,0 +1,171 @@ +from abc import ABC + +import torch +import torch.nn.functional as F + +from indextts.s2mel.modules.diffusion_transformer import DiT +from indextts.s2mel.modules.commons import sequence_mask + +from tqdm import tqdm + +class BASECFM(torch.nn.Module, ABC): + def __init__( + self, + args, + ): + super().__init__() + self.sigma_min = 1e-6 + + self.estimator = None + + self.in_channels = args.DiT.in_channels + + self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss() + + if hasattr(args.DiT, 'zero_prompt_speech_token'): + self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token + else: + self.zero_prompt_speech_token = False + + @torch.inference_mode() + def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5): + """Forward diffusion + + Args: + mu (torch.Tensor): semantic info of reference audio and altered audio + shape: (batch_size, mel_timesteps(795+1069), 512) + x_lens (torch.Tensor): mel frames output + shape: (batch_size, mel_timesteps) + prompt (torch.Tensor): reference mel + shape: (batch_size, 80, 795) + style (torch.Tensor): reference global style + shape: (batch_size, 192) + f0: None + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, 80, mel_timesteps) + """ + B, T = mu.size(0), mu.size(1) + z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span) + return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate) + + def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): semantic info of reference audio and altered audio + shape: (batch_size, mel_timesteps(795+1069), 512) + x_lens (torch.Tensor): mel frames output + shape: (batch_size, mel_timesteps) + prompt (torch.Tensor): reference mel + shape: (batch_size, 80, 795) + style (torch.Tensor): reference global style + shape: (batch_size, 192) + """ + t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + # apply prompt + prompt_len = prompt.size(-1) + prompt_x = torch.zeros_like(x) + prompt_x[..., :prompt_len] = prompt[..., :prompt_len] + x[..., :prompt_len] = 0 + if self.zero_prompt_speech_token: + mu[..., :prompt_len] = 0 + for step in tqdm(range(1, len(t_span))): + dt = t_span[step] - t_span[step - 1] + if inference_cfg_rate > 0: + # Stack original and CFG (null) inputs for batched processing + stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0) + stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0) + stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0) + stacked_x = torch.cat([x, x], dim=0) + stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0) + + # Perform a single forward pass for both original and CFG inputs + stacked_dphi_dt = self.estimator( + stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu, + ) + + # Split the output back into the original and CFG components + dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0) + + # Apply CFG formula + dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt + else: + dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + x[:, :, :prompt_len] = 0 + + return sol[-1] + def forward(self, x1, x_lens, prompt_lens, mu, style): + """Computes diffusion loss + + Args: + mu (torch.Tensor): semantic info of reference audio and altered audio + shape: (batch_size, mel_timesteps(795+1069), 512) + x1: mel + x_lens (torch.Tensor): mel frames output + shape: (batch_size, mel_timesteps) + prompt (torch.Tensor): reference mel + shape: (batch_size, 80, 795) + style (torch.Tensor): reference global style + shape: (batch_size, 192) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = x1.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + prompt = torch.zeros_like(x1) + for bib in range(b): + prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]] + # range covered by prompt are set to 0 + y[bib, :, :prompt_lens[bib]] = 0 + if self.zero_prompt_speech_token: + mu[bib, :, :prompt_lens[bib]] = 0 + + estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens) + loss = 0 + for bib in range(b): + loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]]) + loss /= b + + return loss, estimator_out + (1 - self.sigma_min) * z + + + +class CFM(BASECFM): + def __init__(self, args): + super().__init__( + args + ) + if args.dit_type == "DiT": + self.estimator = DiT(args) + else: + raise NotImplementedError(f"Unknown diffusion type {args.dit_type}") diff --git a/indextts/s2mel/modules/gpt_fast/.ipynb_checkpoints/model-checkpoint.py b/indextts/s2mel/modules/gpt_fast/.ipynb_checkpoints/model-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..0f878bc3a9456f5d11446e68a797eb3b19d22bd7 --- /dev/null +++ b/indextts/s2mel/modules/gpt_fast/.ipynb_checkpoints/model-checkpoint.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if embedding is None: + return self.norm(input) + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + has_cross_attention: bool = False + context_dim: int = 0 + uvit_skip_connection: bool = False + time_as_token: bool = False + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + # self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config.lower() in str(name).lower()] + + # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match, + # take longer name (as it have more symbols matched) + if len(config) > 1: + config.sort(key=len, reverse=True) + assert len(config[0]) != len(config[1]), name # make sure only one 'best' match + + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000), + "7B": dict(n_layer=32, n_head=32, dim=4096), + "13B": dict(n_layer=40, n_head=40, dim=5120), + "30B": dict(n_layer=60, n_head=52, dim=6656), + "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, + rope_base=1000000), # CodeLlama-34B-Python-hf + "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), + "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), + "stories15M": dict(n_layer=6, n_head=6, dim=288), + "stories110M": dict(n_layer=12, n_head=12, dim=768), + + "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, + vocab_size=128256, rope_base=500000), + "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, + vocab_size=128256, rope_base=500000), +} + + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=True): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + dtype = self.norm.project_layer.weight.dtype + device = self.norm.project_layer.weight.device + + if not self.training and use_kv_cache: + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype).to(device) + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, + self.config.rope_base, dtype).to(device) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device) + self.use_kv_cache = use_kv_cache + self.uvit_skip_connection = self.config.uvit_skip_connection + if self.uvit_skip_connection: + self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2] + self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2] + else: + self.layers_emit_skip = [] + self.layers_receive_skip = [] + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_input_pos: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + ) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + if mask is None: # in case of non-causal model + if not self.training and self.use_kv_cache: + mask = self.causal_mask[None, None, input_pos] + else: + mask = self.causal_mask[None, None, input_pos] + mask = mask[..., input_pos] + freqs_cis = self.freqs_cis[input_pos] + if context is not None: + context_freqs_cis = self.freqs_cis[context_input_pos] + else: + context_freqs_cis = None + skip_in_x_list = [] + for i, layer in enumerate(self.layers): + if self.uvit_skip_connection and i in self.layers_receive_skip: + skip_in_x = skip_in_x_list.pop(-1) + else: + skip_in_x = None + x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x) + if self.uvit_skip_connection and i in self.layers_emit_skip: + skip_in_x_list.append(x) + x = self.norm(x, c) + return x + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + if config.has_cross_attention: + self.has_cross_attention = True + self.cross_attention = Attention(config, is_cross_attention=True) + self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + else: + self.has_cross_attention = False + + if config.uvit_skip_connection: + self.skip_in_linear = nn.Linear(config.dim * 2, config.dim) + self.uvit_skip_connection = True + else: + self.uvit_skip_connection = False + + self.time_as_token = config.time_as_token + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Tensor, + freqs_cis: Tensor, + mask: Tensor, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + skip_in_x: Optional[Tensor] = None, + ) -> Tensor: + c = None if self.time_as_token else c + if self.uvit_skip_connection and skip_in_x is not None: + x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1)) + h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos) + if self.has_cross_attention: + h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h, c)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs, is_cross_attention: bool = False): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + if is_cross_attention: + self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False) + self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False) + else: + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + # self._register_load_state_dict_pre_hook(self.load_hook) + + # def load_hook(self, state_dict, prefix, *args): + # if prefix + "wq.weight" in state_dict: + # wq = state_dict.pop(prefix + "wq.weight") + # wk = state_dict.pop(prefix + "wk.weight") + # wv = state_dict.pop(prefix + "wv.weight") + # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + if context is None: + q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1) + context_seqlen = seqlen + else: + q = self.wq(x) + k, v = self.wkv(context).split([kv_size, kv_size], dim=-1) + context_seqlen = context.shape[1] + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/indextts/s2mel/modules/gpt_fast/generate.py b/indextts/s2mel/modules/gpt_fast/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..c58a224965def22a54e8683458daba8b99d21694 --- /dev/null +++ b/indextts/s2mel/modules/gpt_fast/generate.py @@ -0,0 +1,436 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import sys +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet suppported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + +default_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from model import Transformer +from tokenizer import get_tokenizer + +def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[0, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + +def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + +def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + +def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + + +def model_forward(model, x, input_pos): + return model(x, input_pos) + +def speculative_decode( + model: Transformer, + draft_model: Transformer, + cur_token: torch.Tensor, + input_pos: int, + speculate_k: int, + **sampling_kwargs +) -> torch.Tensor: + # draft model inference sequentially + device = cur_token.device + orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device) + draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs) + + draft_tokens = torch.cat(draft_tokens) + # parallel inference on target model using draft tokens + target_logits = model_forward( + model, + torch.cat([cur_token.view(1), draft_tokens]).view(1, -1), + torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device) + ) + target_probs = logits_to_probs(target_logits[0], **sampling_kwargs) + draft_probs = torch.stack(draft_probs) + # q: target prob, p: draft prob + # q >= p: always accept draft token + # q < p: q/p prob to accept draft token + p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens] + accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p) + rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero() + + if rejected_locations.shape[0] == 0: # All draft tokens have been accepted + accept_length = speculate_k + 1 + last_token = multinomial_sample_one_no_sync(target_probs[-1]) + # fill last token into draft model + model_forward( + draft_model, + draft_tokens[-1].view(1, -1), + orig_input_pos + speculate_k, + ) + return torch.cat([draft_tokens, last_token]) + else: + accept_length = rejected_locations[0].item() + p = draft_probs[accept_length] + q = target_probs[accept_length] + new = q - p + new = torch.where(new > 0, new, 0.0) + new = new / new.sum() + next_token = multinomial_sample_one_no_sync(new) + return torch.cat([draft_tokens[:accept_length], next_token]) + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + interactive: bool, + draft_model: Transformer, + speculate_k: Optional[int] = 8, + callback = lambda x: x, + **sampling_kwargs +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + + is_speculative = draft_model is not None + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(0) + T_new = T + max_new_tokens + if interactive: + max_seq_length = 350 + else: + max_seq_length = min(T_new, model.config.block_size) + + device, dtype = prompt.device, prompt.dtype + max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if is_speculative and draft_model is not model: + draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + # create an empty tensor of the expected final shape and fill in the current tokens + empty = torch.empty(T_new, dtype=dtype, device=device) + empty[:T] = prompt + seq = empty + input_pos = torch.arange(0, T, device=device) + + next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone() + if is_speculative: + prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs) + seq[T] = next_token + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + accept_counts = [0] * (speculate_k + 1) + + if is_speculative: + input_pos = input_pos.item() # for speculative decoding easier to keep on host + while input_pos < T_new - 1: + cur_token = next_token.view(()) + + next_tokens = speculative_decode( + model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs + ) + + accept_counts[len(next_tokens) - 1] += 1 + num_added = min(T_new - input_pos - 1, len(next_tokens)) + seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added] + for i in next_tokens[: num_added,]: + callback(i) + input_pos = input_pos + num_added + next_token = next_tokens[-1] + else: + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs) + seq[T + 1:] = torch.cat(generated_tokens) + + generate_stats = { + 'accept_counts': accept_counts + } + return seq, generate_stats + +def encode_tokens(tokenizer, string, bos=True, device=default_device): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + +def _load_model(checkpoint_path, device, precision, use_tp): + use_cuda = 'cuda' in device + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + if "int8" in str(checkpoint_path): + print("Using int8 weight-only quantization!") + from quantize import WeightOnlyInt8QuantHandler + simple_quantizer = WeightOnlyInt8QuantHandler(model) + model = simple_quantizer.convert_for_runtime() + + if "int4" in str(checkpoint_path): + print("Using int4 weight-only quantization!") + path_comps = checkpoint_path.name.split(".") + groupsize = int(path_comps[-2][1:]) + from quantize import WeightOnlyInt4QuantHandler + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + model = simple_quantizer.convert_for_runtime() + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + model.load_state_dict(checkpoint, assign=True) + + if use_tp: + from tp import apply_tp + print("Applying tensor parallel to model ...") + apply_tp(model) + + model = model.to(device=device, dtype=precision) + return model.eval() + +def _get_model_size(model): + model_size = 0 + for name, child in model.named_children(): + if not isinstance(child, torch.nn.Embedding): + model_size += sum( + [ + p.numel() * p.dtype.itemsize + for p in itertools.chain(child.parameters(), child.buffers()) + ] + ) + return model_size + +B_INST, E_INST = "[INST]", "[/INST]" + +def main( + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), + compile: bool = True, + compile_prefill: bool = False, + profile: Optional[Path] = None, + draft_checkpoint_path: Optional[Path] = None, + speculate_k: int = 5, + device=default_device, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer. + """ + assert checkpoint_path.is_file(), checkpoint_path + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + + global print + from tp import maybe_init_dist + rank = maybe_init_dist() + use_tp = rank is not None + if use_tp: + if rank != 0: + # only print on rank 0 + print = lambda *args, **kwargs: None + + print(f"Using device={device}") + precision = torch.bfloat16 + is_speculative = draft_checkpoint_path is not None + is_chat = "chat" in str(checkpoint_path) + + print("Loading model ...") + t0 = time.time() + model = _load_model(checkpoint_path, device, precision, use_tp) + + if is_speculative: + draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp) + else: + draft_model = None + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + model_size = _get_model_size(model) + if compile: + if is_speculative and use_tp: # and ("cuda" in device): + torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case + + if is_speculative: + global model_forward, logits_to_prob + model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) + + global decode_one_token, prefill + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + + # Uncomment to squeeze more perf out of prefill + if compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + + aggregate_metrics = { + 'tokens_per_sec': [], + 'accept_counts': [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? ") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode('.')[0] + done_generating = False + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print(''.join(buffer), end='', flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x : x + t0 = time.perf_counter() + import contextlib + if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y, metrics = generate( + model, + encoded, + max_new_tokens, + draft_model=draft_model, + speculate_k=speculate_k, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + ) + aggregate_metrics['accept_counts'].append(metrics['accept_counts']) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + if use_tp: + prof.export_chrome_trace(f"{profile}_rank_{rank}.json") + else: + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + print(tokenizer.decode(y.tolist())) + else: + print() + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + print("==========") + if is_speculative: + counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])] + acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated] + print(f"Acceptance probs: {acceptance_probs}") + print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}") + + print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Your CLI description.') + + parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') + parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode') + parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') + parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.') + parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') + parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') + parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') + parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)') + parser.add_argument('--profile', type=Path, default=None, help='Profile path.') + parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.') + parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.') + parser.add_argument('--device', type=str, default=default_device, help='Device to use') + + args = parser.parse_args() + main( + args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k, + args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path, + args.speculate_k, args.device + ) diff --git a/indextts/s2mel/modules/gpt_fast/model.py b/indextts/s2mel/modules/gpt_fast/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0f878bc3a9456f5d11446e68a797eb3b19d22bd7 --- /dev/null +++ b/indextts/s2mel/modules/gpt_fast/model.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if embedding is None: + return self.norm(input) + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + has_cross_attention: bool = False + context_dim: int = 0 + uvit_skip_connection: bool = False + time_as_token: bool = False + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + # self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config.lower() in str(name).lower()] + + # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match, + # take longer name (as it have more symbols matched) + if len(config) > 1: + config.sort(key=len, reverse=True) + assert len(config[0]) != len(config[1]), name # make sure only one 'best' match + + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000), + "7B": dict(n_layer=32, n_head=32, dim=4096), + "13B": dict(n_layer=40, n_head=40, dim=5120), + "30B": dict(n_layer=60, n_head=52, dim=6656), + "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, + rope_base=1000000), # CodeLlama-34B-Python-hf + "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672), + "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000), + "stories15M": dict(n_layer=6, n_head=6, dim=288), + "stories110M": dict(n_layer=12, n_head=12, dim=768), + + "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, + vocab_size=128256, rope_base=500000), + "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, + vocab_size=128256, rope_base=500000), +} + + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=True): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + dtype = self.norm.project_layer.weight.dtype + device = self.norm.project_layer.weight.device + + if not self.training and use_kv_cache: + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype).to(device) + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, + self.config.rope_base, dtype).to(device) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device) + self.use_kv_cache = use_kv_cache + self.uvit_skip_connection = self.config.uvit_skip_connection + if self.uvit_skip_connection: + self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2] + self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2] + else: + self.layers_emit_skip = [] + self.layers_receive_skip = [] + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_input_pos: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + ) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + if mask is None: # in case of non-causal model + if not self.training and self.use_kv_cache: + mask = self.causal_mask[None, None, input_pos] + else: + mask = self.causal_mask[None, None, input_pos] + mask = mask[..., input_pos] + freqs_cis = self.freqs_cis[input_pos] + if context is not None: + context_freqs_cis = self.freqs_cis[context_input_pos] + else: + context_freqs_cis = None + skip_in_x_list = [] + for i, layer in enumerate(self.layers): + if self.uvit_skip_connection and i in self.layers_receive_skip: + skip_in_x = skip_in_x_list.pop(-1) + else: + skip_in_x = None + x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x) + if self.uvit_skip_connection and i in self.layers_emit_skip: + skip_in_x_list.append(x) + x = self.norm(x, c) + return x + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + if config.has_cross_attention: + self.has_cross_attention = True + self.cross_attention = Attention(config, is_cross_attention=True) + self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + else: + self.has_cross_attention = False + + if config.uvit_skip_connection: + self.skip_in_linear = nn.Linear(config.dim * 2, config.dim) + self.uvit_skip_connection = True + else: + self.uvit_skip_connection = False + + self.time_as_token = config.time_as_token + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Tensor, + freqs_cis: Tensor, + mask: Tensor, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + skip_in_x: Optional[Tensor] = None, + ) -> Tensor: + c = None if self.time_as_token else c + if self.uvit_skip_connection and skip_in_x is not None: + x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1)) + h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos) + if self.has_cross_attention: + h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h, c)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs, is_cross_attention: bool = False): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + if is_cross_attention: + self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False) + self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False) + else: + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + # self._register_load_state_dict_pre_hook(self.load_hook) + + # def load_hook(self, state_dict, prefix, *args): + # if prefix + "wq.weight" in state_dict: + # wq = state_dict.pop(prefix + "wq.weight") + # wk = state_dict.pop(prefix + "wk.weight") + # wv = state_dict.pop(prefix + "wv.weight") + # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + if context is None: + q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1) + context_seqlen = seqlen + else: + q = self.wq(x) + k, v = self.wkv(context).split([kv_size, kv_size], dim=-1) + context_seqlen = context.shape[1] + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/indextts/s2mel/modules/gpt_fast/quantize.py b/indextts/s2mel/modules/gpt_fast/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..fb566421f68f84aeca22c2ab91f5bff53c819c13 --- /dev/null +++ b/indextts/s2mel/modules/gpt_fast/quantize.py @@ -0,0 +1,622 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import time +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tokenizer import get_tokenizer + +try: + from GPTQ import GenericGPTQRunner, InputRecorder + from eval import get_task_dict, evaluate, lm_eval +except: + pass + +from model import Transformer + +##### Quantization Primitives ###### + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + +def get_group_qparams(w, n_bit=4, groupsize=128): + # needed for GPTQ with padding + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to( + torch.bfloat16 + ).reshape(w.shape[0], -1) + + +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + +def unpack_scales_and_zeros(scales_and_zeros): + assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2 + assert scales_and_zeros.dtype == torch.float + return torch.split(scales_and_zeros.transpose(0, 1), 1, 2) + + +def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int32 = ( + to_quant.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape_as(w) + ) + + return w_int32 + + +def group_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_group_qparams(w, n_bit, groupsize) + w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + return w_int32, scales_and_zeros + + +def group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit=4, groupsize=128 +): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int32.shape[-1] + assert w_int32.shape[-1] % groupsize == 0 + assert w_int32.dim() == 2 + + w_int32_grouped = w_int32.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = ( + w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) + ) + return w_dq + + +def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): + scales, zeros = unpack_scales_and_zeros(scales_and_zeros) + return group_dequantize_tensor_from_qparams( + w_int32, scales, zeros, n_bit, groupsize + ) + +class QuantHandler: + def __init__(self, mod): + self.mod = mod + + def create_quantized_state_dict(self) -> "StateDict": + pass + + def convert_for_runtime(self) -> "nn.Module": + pass + +class GPTQQuantHandler(QuantHandler): + """ + This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. + Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement + __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime. + + The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and + create_quantized_state_dict. Here is a description of each function. + + get_qparams_func: + A function that calculates the quantization qparams for an input tensor. + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + qparams: it can have any format but will need to be handled by the other defined functions below. + + quantize_func: + A function that applies quantization to an input tensor. It should be noted + that this function needs to be able to handle quantizing the entire weight tensor, a single group, + or a single column. + Args: + weight: A 2d weight tensor with non-integer dtype. + qparams: the output from get_qparams_func + Returns: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + + + dequantize_func: + A function that dequantizes an input quantized weight tensor. It should be noted + that this function needs to be able to handle dequantizing the entire weight tensor, a single group, + or a single column. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + weight: A 2d weight tensor with non-integer dtype. + + combine_qparams_list_func: + A function that combines several qparams into one qparam. + Args: + qparams_list: a list of qparams objects, each obtained by calling get_qparams_func + on a single group from a weight tensor + Returns: + qparams: an object of the same format as the qparams above. + + skip_layer_func: + A function that determines which linear layers should be skipped during GPTQ + Args: + weight: A 2d weight tensor with non-integer dtype. + Returns: + skip: boolean indicating whether layer should be skipped + + make_names_and_values_dict_func: + A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they + should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here. + Args: + quantized_weight: A 2d quantized weight tensor (generally with an integer dtype) + qparams: the output from get_qparams_func + Returns: + names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the + corresponding quantized weights and qparams. + """ + def __init__(self): + assert self.mod is not None + assert self.get_qparams_func is not None + assert self.quantize_func is not None + assert self.dequantize_func is not None + assert self.combine_qparams_list_func is not None + assert self.make_names_and_values_dict_func is not None + + @staticmethod + def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": + input_recorder = InputRecorder( + model, + tokenizer, + calibration_seq_length, + pad_calibration_inputs, + ) + + try: + lm_eval.tasks.initialize_tasks() + except: + pass + task_dict = get_task_dict(calibration_tasks) + print("Obtaining GPTQ calibration inputs on: ", calibration_tasks) + + evaluate( + input_recorder, + task_dict, + limit=calibration_limit, + ) + inputs = input_recorder.get_recorded_inputs() + assert inputs is not None, ( + f"No inputs were collected, use a task other than {calibration_tasks}, "+ + f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ + f"{calibration_seq_length})" + ) + print(f"Obtained {len(inputs[0].values)} calibration samples") + return inputs + + @torch.no_grad() + def create_quantized_state_dict( + self, + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) -> "StateDict": + inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) + print("Tracing model for GPTQ") + GPTQ_runner = GenericGPTQRunner( + self.mod, + inputs, + blocksize, + percdamp, + groupsize, + ).configure_quantization_mode( + self.get_qparams_func, + self.quantize_func, + self.dequantize_func, + self.combine_qparams_list_func, + self.make_names_and_values_dict_func, + self.skip_layer_func + ) + + print("Applying GPTQ to weights") + GPTQ_runner.run() + return GPTQ_runner.get_quantized_state_dict() + + def convert_for_runtime(self) -> "nn.Module": + pass + +##### Weight-only int8 per-channel quantized code ###### + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features)) + else: + replace_linear_weight_only_int8_per_channel(child) + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) + cur_state_dict[f"{fqn}.weight"] = int8_weight + cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype) + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + +##### weight only int4 per channel groupwise quantized code ###### + +def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + weight_int32, scales_and_zeros = group_quantize_tensor( + weight_bf16, n_bit=4, groupsize=groupsize + ) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) + return weight_int4pack, scales_and_zeros + + +def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + +def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): + return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): + setattr(module, name, WeightOnlyInt4Linear( + child.in_features, child.out_features, bias=False, + groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, + )) + elif padding: + setattr(module, name, WeightOnlyInt4Linear( + child.in_features, child.out_features, bias=False, + groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, + )) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding) + + +class WeightOnlyInt4QuantHandler: + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + assert groupsize in [32, 64, 128, 256] + assert inner_k_tiles in [2, 4, 8] + + @torch.no_grad() + def create_quantized_state_dict(self, use_cuda = True): + if use_cuda: + device="cuda" + else: + device="cpu" + + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + weight = mod.weight.data + if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): + if self.padding: + from model import find_multiple + import torch.nn.functional as F + print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + else: + print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it") + continue + weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu') + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu') + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) + return self.mod + +class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + from model import find_multiple + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) + self.quantize_func = lambda w, qparams: \ + group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) + self.dequantize_func = lambda q, qparams: \ + group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() + self.combine_qparams_list_func = lambda qparams_list: \ + [torch.cat(x, dim=1) for x in zip(*qparams_list)] + # skip unless padding=True or its correctly sized + self.skip_layer_func = lambda linear_weight: not ( + _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding + ) + # we need to do the padding here, both for q and the qparams if necessary + def make_names_and_values_dict_func(q, qparams): + k = q.shape[1] + new_k = find_multiple(k, 1024) + # how much we need to pad the weight + delta_k = new_k - q.shape[1] + final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + scales_and_zeros = pack_scales_and_zeros(*qparams) + # how many new groups we need for padded weight + delta_groups = new_k // groupsize - scales_and_zeros.shape[0] + final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + return {"weight": final_q, "scales_and_zeros": final_s_and_z} + self.make_names_and_values_dict_func = make_names_and_values_dict_func + super().__init__() + + + def convert_for_runtime(self): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding) + return self.mod + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, in_features: int, out_features: int, + bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, + ) -> None: + super().__init__() + self.padding = padding + if padding: + from model import find_multiple + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + self.register_buffer( + "weight", + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + ) + self.register_buffer( + "scales_and_zeros", + torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4( + input, + self.weight, self.scales_and_zeros, self.out_features, self.groupsize + ) + + +def quantize( + checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), + mode: str = 'int8', + # following arguments only available when setting int4 quantization. + groupsize: int = 128, + # following arguments only used for GPTQ + calibration_tasks: list = ["hellaswag"], + calibration_limit: int = 1000, + calibration_seq_length: int = 100, + pad_calibration_inputs: bool = False, + percdamp: float = .01, + blocksize: int = 128, + label: str = '', +) -> None: + assert checkpoint_path.is_file(), checkpoint_path + + device = 'cpu' + precision = torch.bfloat16 + + print("Loading model ...") + t0 = time.time() + + with torch.device('meta'): + model = Transformer.from_name(checkpoint_path.parent.name) + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + model.load_state_dict(checkpoint, assign=True) + model = model.to(dtype=precision, device=device) + + if mode == 'int8': + print("Quantizing model weights for int8 weight-only symmetric per-channel quantization") + quant_handler = WeightOnlyInt8QuantHandler(model) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f'{label}int8.pth') + + elif mode == 'int4': + print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization") + quant_handler = WeightOnlyInt4QuantHandler(model, groupsize) + quantized_state_dict = quant_handler.create_quantized_state_dict() + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth") + + elif mode == 'int4-gptq': + print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...") + quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize) + + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + tokenizer = get_tokenizer(tokenizer_path, checkpoint_path) + + quantized_state_dict = quant_handler.create_quantized_state_dict( + tokenizer, + blocksize, + percdamp, + groupsize, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs + ) + + dir_name = checkpoint_path.parent + base_name = checkpoint_path.name + new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth") + else: + raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]") + + quantize_path = dir_name / new_base_name + print(f"Writing quantized weights to {quantize_path}") + quantize_path.unlink(missing_ok=True) # remove existing file if one already there + torch.save(quantized_state_dict, quantize_path) + print(f"Quantization complete took {time.time() - t0:.02f} seconds") + return + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description='Quantize a model.') + parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.') + parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform') + parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.') + parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') + parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration') + parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration') + parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower') + parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening') + parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq') + parser.add_argument('--label', type=str, default='_', help='label to add to output filename') + + args = parser.parse_args() + quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label) diff --git a/indextts/s2mel/modules/hifigan/f0_predictor.py b/indextts/s2mel/modules/hifigan/f0_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..36b85f4ed90c3a412cb179f49ccb471132a86550 --- /dev/null +++ b/indextts/s2mel/modules/hifigan/f0_predictor.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + + +class ConvRNNF0Predictor(nn.Module): + def __init__(self, + num_class: int = 1, + in_channels: int = 80, + cond_channels: int = 512 + ): + super().__init__() + + self.num_class = num_class + self.condnet = nn.Sequential( + weight_norm( + nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + weight_norm( + nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) + ), + nn.ELU(), + ) + self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.condnet(x) + x = x.transpose(1, 2) + return torch.abs(self.classifier(x).squeeze(-1)) diff --git a/indextts/s2mel/modules/hifigan/generator.py b/indextts/s2mel/modules/hifigan/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..867894cee55a0746b93c1ba070871b7e5e5eba2b --- /dev/null +++ b/indextts/s2mel/modules/hifigan/generator.py @@ -0,0 +1,454 @@ +# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HIFI-GAN""" + +import typing as tp +import numpy as np +from scipy.signal import get_window +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d +from torch.nn import ConvTranspose1d +from torch.nn.utils import remove_weight_norm +from torch.nn.utils import weight_norm +from torch.distributions.uniform import Uniform + +from torch import sin +from torch.nn.parameter import Parameter + + +"""hifigan based generator implementation. + +This code is modified from https://github.com/jik876/hifi-gan + ,https://github.com/kan-bayashi/ParallelWaveGAN and + https://github.com/NVIDIA/BigVGAN + +""" +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + + +class ResBlock(torch.nn.Module): + """Residual block module in HiFiGAN/BigVGAN.""" + def __init__( + self, + channels: int = 512, + kernel_size: int = 3, + dilations: tp.List[int] = [1, 3, 5], + ): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList() + self.convs2 = nn.ModuleList() + + for dilation in dilations: + self.convs1.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + padding=get_padding(kernel_size, dilation) + ) + ) + ) + self.convs2.append( + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1) + ) + ) + ) + self.convs1.apply(init_weights) + self.convs2.apply(init_weights) + self.activations1 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs1)) + ]) + self.activations2 = nn.ModuleList([ + Snake(channels, alpha_logscale=False) + for _ in range(len(self.convs2)) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for idx in range(len(self.convs1)): + xt = self.activations1[idx](x) + xt = self.convs1[idx](xt) + xt = self.activations2[idx](xt) + xt = self.convs2[idx](xt) + x = xt + x + return x + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): + remove_weight_norm(self.convs1[idx]) + remove_weight_norm(self.convs2[idx]) + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + @torch.no_grad() + def forward(self, f0): + """ + :param f0: [B, 1, sample_len], Hz + :return: [B, 1, sample_len] + """ + + F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) + for i in range(self.harmonic_num + 1): + F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate + + theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) + u_dist = Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) + phase_vec[:, 0, :] = 0 + + # generate sine waveforms + sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) + + # generate uv signal + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + with torch.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2)) + sine_wavs = sine_wavs.transpose(1, 2) + uv = uv.transpose(1, 2) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class HiFTGenerator(nn.Module): + """ + HiFTNet Generator: Neural Source Filter + ISTFTNet + https://arxiv.org/abs/2309.09493 + """ + def __init__( + self, + in_channels: int = 80, + base_channels: int = 512, + nb_harmonics: int = 8, + sampling_rate: int = 22050, + nsf_alpha: float = 0.1, + nsf_sigma: float = 0.003, + nsf_voiced_threshold: float = 10, + upsample_rates: tp.List[int] = [8, 8], + upsample_kernel_sizes: tp.List[int] = [16, 16], + istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4}, + resblock_kernel_sizes: tp.List[int] = [3, 7, 11], + resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + source_resblock_kernel_sizes: tp.List[int] = [7, 11], + source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]], + lrelu_slope: float = 0.1, + audio_limit: float = 0.99, + f0_predictor: torch.nn.Module = None, + ): + super(HiFTGenerator, self).__init__() + + self.out_channels = 1 + self.nb_harmonics = nb_harmonics + self.sampling_rate = sampling_rate + self.istft_params = istft_params + self.lrelu_slope = lrelu_slope + self.audio_limit = audio_limit + + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=sampling_rate, + upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], + harmonic_num=nb_harmonics, + sine_amp=nsf_alpha, + add_noise_std=nsf_sigma, + voiced_threshod=nsf_voiced_threshold) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) + + self.conv_pre = weight_norm( + Conv1d(in_channels, base_channels, 7, 1, padding=3) + ) + + # Up + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + base_channels // (2**i), + base_channels // (2**(i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + # Down + self.source_downs = nn.ModuleList() + self.source_resblocks = nn.ModuleList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, + source_resblock_dilation_sizes)): + if u == 1: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) + ) + else: + self.source_downs.append( + Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2)) + ) + + self.source_resblocks.append( + ResBlock(base_channels // (2 ** (i + 1)), k, d) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = base_channels // (2**(i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = nn.ReflectionPad1d((1, 0)) + self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) + self.f0_predictor = f0_predictor + + def _f02source(self, f0: torch.Tensor) -> torch.Tensor: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + + har_source, _, _ = self.m_source(f0) + return har_source.transpose(1, 2) + + def _stft(self, x): + spec = torch.stft( + x, + self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), + return_complex=True) + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[..., 0], spec[..., 1] + + def _istft(self, magnitude, phase): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) + inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) + return inverse_transform + + def forward(self, x: torch.Tensor, f0=None) -> torch.Tensor: + if f0 is None: + f0 = self.f0_predictor(x) + s = self._f02source(f0) + + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) + s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) + + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, self.lrelu_slope) + x = self.ups[i](x) + + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + + # fusion + si = self.source_downs[i](s_stft) + si = self.source_resblocks[i](si) + x = x + si + + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = F.leaky_relu(x) + x = self.conv_post(x) + magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) + phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy + + x = self._istft(magnitude, phase) + x = torch.clamp(x, -self.audio_limit, self.audio_limit) + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + self.source_module.remove_weight_norm() + for l in self.source_downs: + remove_weight_norm(l) + for l in self.source_resblocks: + l.remove_weight_norm() + + @torch.inference_mode() + def inference(self, mel: torch.Tensor, f0=None) -> torch.Tensor: + return self.forward(x=mel, f0=f0) diff --git a/indextts/s2mel/modules/layers.py b/indextts/s2mel/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..bc6567b47ef60e9a64f85b7c088b7fac1683fdb5 --- /dev/null +++ b/indextts/s2mel/modules/layers.py @@ -0,0 +1,354 @@ +import math +import torch +from torch import nn +from typing import Optional, Any +from torch import Tensor +import torch.nn.functional as F +import torchaudio +import torchaudio.functional as audio_F + +import random +random.seed(0) + + +def _get_activation_fn(activ): + if activ == 'relu': + return nn.ReLU() + elif activ == 'lrelu': + return nn.LeakyReLU(0.2) + elif activ == 'swish': + return lambda x: x*torch.sigmoid(x) + else: + raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ) + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): + super(ConvNorm, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + +class CausualConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None): + super(CausualConv, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) * 2 + else: + self.padding = padding * 2 + self.conv = nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=self.padding, + dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) + + def forward(self, x): + x = self.conv(x) + x = x[:, :, :-self.padding] + return x + +class CausualBlock(nn.Module): + def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'): + super(CausualBlock, self).__init__() + self.blocks = nn.ModuleList([ + self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p) + for i in range(n_conv)]) + + def forward(self, x): + for block in self.blocks: + res = x + x = block(x) + x += res + return x + + def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2): + layers = [ + CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), + _get_activation_fn(activ), + nn.BatchNorm1d(hidden_dim), + nn.Dropout(p=dropout_p), + CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), + _get_activation_fn(activ), + nn.Dropout(p=dropout_p) + ] + return nn.Sequential(*layers) + +class ConvBlock(nn.Module): + def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'): + super().__init__() + self._n_groups = 8 + self.blocks = nn.ModuleList([ + self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p) + for i in range(n_conv)]) + + + def forward(self, x): + for block in self.blocks: + res = x + x = block(x) + x += res + return x + + def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2): + layers = [ + ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), + _get_activation_fn(activ), + nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim), + nn.Dropout(p=dropout_p), + ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), + _get_activation_fn(activ), + nn.Dropout(p=dropout_p) + ] + return nn.Sequential(*layers) + +class LocationLayer(nn.Module): + def __init__(self, attention_n_filters, attention_kernel_size, + attention_dim): + super(LocationLayer, self).__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = ConvNorm(2, attention_n_filters, + kernel_size=attention_kernel_size, + padding=padding, bias=False, stride=1, + dilation=1) + self.location_dense = LinearNorm(attention_n_filters, attention_dim, + bias=False, w_init_gain='tanh') + + def forward(self, attention_weights_cat): + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class Attention(nn.Module): + def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + attention_location_n_filters, attention_location_kernel_size): + super(Attention, self).__init__() + self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, + bias=False, w_init_gain='tanh') + self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, + w_init_gain='tanh') + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer(attention_location_n_filters, + attention_location_kernel_size, + attention_dim) + self.score_mask_value = -float("inf") + + def get_alignment_energies(self, query, processed_memory, + attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v(torch.tanh( + processed_query + processed_attention_weights + processed_memory)) + + energies = energies.squeeze(-1) + return energies + + def forward(self, attention_hidden_state, memory, processed_memory, + attention_weights_cat, mask): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + alignment = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat) + + if mask is not None: + alignment.data.masked_fill_(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class ForwardAttentionV2(nn.Module): + def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + attention_location_n_filters, attention_location_kernel_size): + super(ForwardAttentionV2, self).__init__() + self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, + bias=False, w_init_gain='tanh') + self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, + w_init_gain='tanh') + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer(attention_location_n_filters, + attention_location_kernel_size, + attention_dim) + self.score_mask_value = -float(1e20) + + def get_alignment_energies(self, query, processed_memory, + attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: prev. and cumulative att weights (B, 2, max_time) + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v(torch.tanh( + processed_query + processed_attention_weights + processed_memory)) + + energies = energies.squeeze(-1) + return energies + + def forward(self, attention_hidden_state, memory, processed_memory, + attention_weights_cat, mask, log_alpha): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + log_energy = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat) + + #log_energy = + + if mask is not None: + log_energy.data.masked_fill_(mask, self.score_mask_value) + + #attention_weights = F.softmax(alignment, dim=1) + + #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] + #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] + + #log_total_score = log_alpha + content_score + + #previous_attention_weights = attention_weights_cat[:,0,:] + + log_alpha_shift_padded = [] + max_time = log_energy.size(1) + for sft in range(2): + shifted = log_alpha[:,:max_time-sft] + shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value) + log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) + + biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2) + + log_alpha_new = biased + log_energy + + attention_weights = F.softmax(log_alpha_new, dim=1) + + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights, log_alpha_new + + +class PhaseShuffle2d(nn.Module): + def __init__(self, n=2): + super(PhaseShuffle2d, self).__init__() + self.n = n + self.random = random.Random(1) + + def forward(self, x, move=None): + # x.size = (B, C, M, L) + if move is None: + move = self.random.randint(-self.n, self.n) + + if move == 0: + return x + else: + left = x[:, :, :, :move] + right = x[:, :, :, move:] + shuffled = torch.cat([right, left], dim=3) + return shuffled + +class PhaseShuffle1d(nn.Module): + def __init__(self, n=2): + super(PhaseShuffle1d, self).__init__() + self.n = n + self.random = random.Random(1) + + def forward(self, x, move=None): + # x.size = (B, C, M, L) + if move is None: + move = self.random.randint(-self.n, self.n) + + if move == 0: + return x + else: + left = x[:, :, :move] + right = x[:, :, move:] + shuffled = torch.cat([right, left], dim=2) + + return shuffled + +class MFCC(nn.Module): + def __init__(self, n_mfcc=40, n_mels=80): + super(MFCC, self).__init__() + self.n_mfcc = n_mfcc + self.n_mels = n_mels + self.norm = 'ortho' + dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm) + self.register_buffer('dct_mat', dct_mat) + + def forward(self, mel_specgram): + if len(mel_specgram.shape) == 2: + mel_specgram = mel_specgram.unsqueeze(0) + unsqueezed = True + else: + unsqueezed = False + # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) + # -> (channel, time, n_mfcc).tranpose(...) + mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) + + # unpack batch + if unsqueezed: + mfcc = mfcc.squeeze(0) + return mfcc diff --git a/indextts/s2mel/modules/length_regulator.py b/indextts/s2mel/modules/length_regulator.py new file mode 100644 index 0000000000000000000000000000000000000000..253c35ebc720626f5c2b27d83243301c33a3e0d2 --- /dev/null +++ b/indextts/s2mel/modules/length_regulator.py @@ -0,0 +1,141 @@ +from typing import Tuple +import torch +import torch.nn as nn +from torch.nn import functional as F +from indextts.s2mel.modules.commons import sequence_mask +import numpy as np +from indextts.s2mel.dac.nn.quantize import VectorQuantize + +# f0_bin = 256 +f0_max = 1100.0 +f0_min = 50.0 +f0_mel_min = 1127 * np.log(1 + f0_min / 700) +f0_mel_max = 1127 * np.log(1 + f0_max / 700) + +def f0_to_coarse(f0, f0_bin): + f0_mel = 1127 * (1 + f0 / 700).log() + a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) + b = f0_mel_min * a - 1. + f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel) + # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1)) + f0_coarse = torch.round(f0_mel).long() + f0_coarse = f0_coarse * (f0_coarse > 0) + f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) + f0_coarse = f0_coarse * (f0_coarse < f0_bin) + f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1)) + return f0_coarse + +class InterpolateRegulator(nn.Module): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + is_discrete: bool = False, + in_channels: int = None, # only applies to continuous input + vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input + codebook_size: int = 1024, # for discrete only + out_channels: int = None, + groups: int = 1, + n_codebooks: int = 1, # number of codebooks + quantizer_dropout: float = 0.0, # dropout for quantizer + f0_condition: bool = False, + n_f0_bins: int = 512, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = nn.ModuleList([]) + if len(sampling_ratios) > 0: + self.interpolate = True + for _ in sampling_ratios: + module = nn.Conv1d(channels, channels, 3, 1, 1) + norm = nn.GroupNorm(groups, channels) + act = nn.Mish() + model.extend([module, norm, act]) + else: + self.interpolate = False + model.append( + nn.Conv1d(channels, out_channels, 1, 1) + ) + self.model = nn.Sequential(*model) + self.embedding = nn.Embedding(codebook_size, channels) + self.is_discrete = is_discrete + + self.mask_token = nn.Parameter(torch.zeros(1, channels)) + + self.n_codebooks = n_codebooks + if n_codebooks > 1: + self.extra_codebooks = nn.ModuleList([ + nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1) + ]) + self.extra_codebook_mask_tokens = nn.ParameterList([ + nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1) + ]) + self.quantizer_dropout = quantizer_dropout + + if f0_condition: + self.f0_embedding = nn.Embedding(n_f0_bins, channels) + self.f0_condition = f0_condition + self.n_f0_bins = n_f0_bins + self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins) + self.f0_mask = nn.Parameter(torch.zeros(1, channels)) + else: + self.f0_condition = False + + if not is_discrete: + self.content_in_proj = nn.Linear(in_channels, channels) + if vector_quantize: + self.vq = VectorQuantize(channels, codebook_size, 8) + + def forward(self, x, ylens=None, n_quantizers=None, f0=None): + # apply token drop + if self.training: + n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks + dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],)) + n_dropout = int(x.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(x.device) + # decide whether to drop for each sample in batch + else: + n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers) + if self.is_discrete: + if self.n_codebooks > 1: + assert len(x.size()) == 3 + x_emb = self.embedding(x[:, 0]) + for i, emb in enumerate(self.extra_codebooks): + x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1]) + # add mask token if not using this codebook + # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i] + x = x_emb + elif self.n_codebooks == 1: + if len(x.size()) == 2: + x = self.embedding(x) + else: + x = self.embedding(x[:, 0]) + else: + x = self.content_in_proj(x) + # x in (B, T, D) + mask = sequence_mask(ylens).unsqueeze(-1) + if self.interpolate: + x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') + else: + x = x.transpose(1, 2).contiguous() + mask = mask[:, :x.size(2), :] + ylens = ylens.clamp(max=x.size(2)).long() + if self.f0_condition: + if f0 is None: + x = x + self.f0_mask.unsqueeze(-1) + else: + #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T) + quantized_f0 = f0_to_coarse(f0, self.n_f0_bins) + quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long() + f0_emb = self.f0_embedding(quantized_f0) + f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') + x = x + f0_emb + out = self.model(x).transpose(1, 2).contiguous() + if hasattr(self, 'vq'): + out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2)) + out_q = out_q.transpose(1, 2) + return out_q * mask, ylens, codes, commitment_loss, codebook_loss + olens = ylens + return out * mask, olens, None, None, None diff --git a/indextts/s2mel/modules/openvoice/__init__.py b/indextts/s2mel/modules/openvoice/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/s2mel/modules/openvoice/api.py b/indextts/s2mel/modules/openvoice/api.py new file mode 100644 index 0000000000000000000000000000000000000000..424bb4d8a262f71e7692528f0a05e06ca5f7f982 --- /dev/null +++ b/indextts/s2mel/modules/openvoice/api.py @@ -0,0 +1,186 @@ +import torch +import numpy as np +import re +import soundfile +from . import utils +from . import commons +import os +import librosa +# from openvoice.text import text_to_sequence +from .mel_processing import spectrogram_torch +from .models import SynthesizerTrn + + +class OpenVoiceBaseClass(object): + def __init__(self, + config_path, + device='cuda:0'): + if 'cuda' in device: + assert torch.cuda.is_available() + + hps = utils.get_hparams_from_file(config_path) + + model = SynthesizerTrn( + len(getattr(hps, 'symbols', [])), + hps.data.filter_length // 2 + 1, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) + + model.eval() + self.model = model + self.hps = hps + self.device = device + + def load_ckpt(self, ckpt_path): + checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device)) + a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False) + print("Loaded checkpoint '{}'".format(ckpt_path)) + print('missing/unexpected keys:', a, b) + + +class BaseSpeakerTTS(OpenVoiceBaseClass): + language_marks = { + "english": "EN", + "chinese": "ZH", + } + + @staticmethod + def get_text(text, hps, is_symbol): + text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners) + if hps.data.add_blank: + text_norm = commons.intersperse(text_norm, 0) + text_norm = torch.LongTensor(text_norm) + return text_norm + + @staticmethod + def audio_numpy_concat(segment_data_list, sr, speed=1.): + audio_segments = [] + for segment_data in segment_data_list: + audio_segments += segment_data.reshape(-1).tolist() + audio_segments += [0] * int((sr * 0.05)/speed) + audio_segments = np.array(audio_segments).astype(np.float32) + return audio_segments + + @staticmethod + def split_sentences_into_pieces(text, language_str): + texts = utils.split_sentence(text, language_str=language_str) + print(" > Text splitted to sentences.") + print('\n'.join(texts)) + print(" > ===========================") + return texts + + def tts(self, text, output_path, speaker, language='English', speed=1.0): + mark = self.language_marks.get(language.lower(), None) + assert mark is not None, f"language {language} is not supported" + + texts = self.split_sentences_into_pieces(text, mark) + + audio_list = [] + for t in texts: + t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t) + t = f'[{mark}]{t}[{mark}]' + stn_tst = self.get_text(t, self.hps, False) + device = self.device + speaker_id = self.hps.speakers[speaker] + with torch.no_grad(): + x_tst = stn_tst.unsqueeze(0).to(device) + x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device) + sid = torch.LongTensor([speaker_id]).to(device) + audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6, + length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy() + audio_list.append(audio) + audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed) + + if output_path is None: + return audio + else: + soundfile.write(output_path, audio, self.hps.data.sampling_rate) + + +class ToneColorConverter(OpenVoiceBaseClass): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # if kwargs.get('enable_watermark', True): + # import wavmark + # self.watermark_model = wavmark.load_model().to(self.device) + # else: + # self.watermark_model = None + self.version = getattr(self.hps, '_version_', "v1") + + + + def extract_se(self, waves, wave_lengths): + + device = self.device + hps = self.hps + gs = [] + + for wav_tensor, wav_len in zip(waves, wave_lengths): + y = wav_tensor[:wav_len] + y = y[None, :] + y = spectrogram_torch(y, hps.data.filter_length, + hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, + center=False).to(device) + with torch.no_grad(): + g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1) + gs.append(g.detach()) + gs = torch.stack(gs) + gs = gs.squeeze(1).squeeze(-1) + return gs + + def convert(self, src_waves, src_wave_lengths, src_se, tgt_se, tau=0.3, message="default"): + hps = self.hps + # load audio + with torch.no_grad(): + y = src_waves + spec = spectrogram_torch(y, hps.data.filter_length, + hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, + center=False).to(self.device) + spec_lengths = src_wave_lengths // hps.data.hop_length + spec_lengths = spec_lengths.clamp(min=1, max=spec.size(2)) + audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se.unsqueeze(-1), sid_tgt=tgt_se.unsqueeze(-1), tau=tau)[0] + return audio + + def add_watermark(self, audio, message): + # if self.watermark_model is None: + return audio + device = self.device + bits = utils.string_to_bits(message).reshape(-1) + n_repeat = len(bits) // 32 + + K = 16000 + coeff = 2 + for n in range(n_repeat): + trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] + if len(trunck) != K: + print('Audio too short, fail to add watermark') + break + message_npy = bits[n * 32: (n + 1) * 32] + + with torch.no_grad(): + signal = torch.FloatTensor(trunck).to(device)[None] + message_tensor = torch.FloatTensor(message_npy).to(device)[None] + signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor) + signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze() + audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy + return audio + + def detect_watermark(self, audio, n_repeat): + bits = [] + K = 16000 + coeff = 2 + for n in range(n_repeat): + trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] + if len(trunck) != K: + print('Audio too short, fail to detect watermark') + return 'Fail' + with torch.no_grad(): + signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0) + message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze() + bits.append(message_decoded_npy) + bits = np.stack(bits).reshape(-1, 8) + message = utils.bits_to_string(bits) + return message + diff --git a/indextts/s2mel/modules/openvoice/attentions.py b/indextts/s2mel/modules/openvoice/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..355115743ff9f5899adfc44c5075053c81066ba1 --- /dev/null +++ b/indextts/s2mel/modules/openvoice/attentions.py @@ -0,0 +1,465 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from . import commons +import logging + +logger = logging.getLogger(__name__) + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + isflow=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + # if isflow: + # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) + # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) + # self.cond_layer = weight_norm(cond_layer, name='weight') + # self.gin_channels = 256 + self.cond_layer_idx = self.n_layers + if "gin_channels" in kwargs: + self.gin_channels = kwargs["gin_channels"] + if self.gin_channels != 0: + self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) + # vits2 says 3rd block, so idx is 2 by default + self.cond_layer_idx = ( + kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 + ) + # logging.debug(self.gin_channels, self.cond_layer_idx) + assert ( + self.cond_layer_idx < self.n_layers + ), "cond_layer_idx should be less than n_layers" + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, g=None): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + if i == self.cond_layer_idx and g is not None: + g = self.spk_emb_linear(g.transpose(1, 2)) + g = g.transpose(1, 2) + x = x + g + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # pad along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json b/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json new file mode 100644 index 0000000000000000000000000000000000000000..3e33566b0d976167bd5f15801ef7005d59143e2f --- /dev/null +++ b/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json @@ -0,0 +1,57 @@ +{ + "_version_": "v2", + "data": { + "sampling_rate": 22050, + "filter_length": 1024, + "hop_length": 256, + "win_length": 1024, + "n_speakers": 0 + }, + "model": { + "zero_g": true, + "inter_channels": 192, + "hidden_channels": 192, + "filter_channels": 768, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": 0.1, + "resblock": "1", + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "upsample_rates": [ + 8, + 8, + 2, + 2 + ], + "upsample_initial_channel": 512, + "upsample_kernel_sizes": [ + 16, + 16, + 4, + 4 + ], + "gin_channels": 256 + } +} \ No newline at end of file diff --git a/indextts/s2mel/modules/openvoice/commons.py b/indextts/s2mel/modules/openvoice/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fa07f65b1681e1f469b04b2fe689b7c174eaaa --- /dev/null +++ b/indextts/s2mel/modules/openvoice/commons.py @@ -0,0 +1,160 @@ +import math +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + layer = pad_shape[::-1] + pad_shape = [item for sublist in layer for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + layer = pad_shape[::-1] + pad_shape = [item for sublist in layer for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/indextts/s2mel/modules/openvoice/mel_processing.py b/indextts/s2mel/modules/openvoice/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..d47446966952228b6ca24569d3b0aef93c749ad9 --- /dev/null +++ b/indextts/s2mel/modules/openvoice/mel_processing.py @@ -0,0 +1,183 @@ +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + # if torch.min(y) < -1.1: + # print("min value is ", torch.min(y)) + # if torch.max(y) > 1.1: + # print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False): + # if torch.min(y) < -1.: + # print('min value is ', torch.min(y)) + # if torch.max(y) > 1.: + # print('max value is ', torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + + # ******************** original ************************# + # y = y.squeeze(1) + # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + + # ******************** ConvSTFT ************************# + freq_cutoff = n_fft // 2 + 1 + fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft))) + forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1]) + forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float() + + import torch.nn.functional as F + + # if center: + # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1) + assert center is False + + forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size) + spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1) + + + # ******************** Verification ************************# + spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + assert torch.allclose(spec1, spec2, atol=1e-4) + + spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=spec.dtype, device=spec.device + ) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=y.dtype, device=y.device + ) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec \ No newline at end of file diff --git a/indextts/s2mel/modules/openvoice/models.py b/indextts/s2mel/modules/openvoice/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1558bb3a3ac9fadc600948d04f97946502b16e7c --- /dev/null +++ b/indextts/s2mel/modules/openvoice/models.py @@ -0,0 +1,499 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from . import commons +from . import modules +from . import attentions + +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from .commons import init_weights, get_padding + + +class TextEncoder(nn.Module): + def __init__(self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class DurationPredictor(nn.Module): + def __init__( + self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 + ): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) + logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None, tau=1.0): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for layer in self.ups: + remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + + +class ReferenceEncoder(nn.Module): + """ + inputs --- [N, Ty/r, n_mels*r] mels + outputs --- [N, ref_enc_gru_size] + """ + + def __init__(self, spec_channels, gin_channels=0, layernorm=True): + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, gin_channels) + if layernorm: + self.layernorm = nn.LayerNorm(self.spec_channels) + else: + self.layernorm = None + + def forward(self, inputs, mask=None): + N = inputs.size(0) + + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + if self.layernorm is not None: + out = self.layernorm(out) + + for conv in self.convs: + out = conv(out) + # out = wn(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)) + + def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + for i in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__( + self, + n_vocab, + spec_channels, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=256, + gin_channels=256, + zero_g=False, + **kwargs + ): + super().__init__() + + self.dec = Generator( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + self.n_speakers = n_speakers + if n_speakers == 0: + self.ref_enc = ReferenceEncoder(spec_channels, gin_channels) + else: + self.enc_p = TextEncoder(n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + self.emb_g = nn.Embedding(n_speakers, gin_channels) + self.zero_g = zero_g + + def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \ + + self.dp(x, x_mask, g=g) * (1 - sdp_ratio) + + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:,:,:max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0): + g_src = sid_src + g_tgt = sid_tgt + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt)) + return o_hat, y_mask, (z, z_p, z_hat) diff --git a/indextts/s2mel/modules/openvoice/modules.py b/indextts/s2mel/modules/openvoice/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..5af5d1152aa1790850f16155c323556c17bf9900 --- /dev/null +++ b/indextts/s2mel/modules/openvoice/modules.py @@ -0,0 +1,598 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from torch.nn import Conv1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from . import commons +from .commons import init_weights, get_padding +from .transforms import piecewise_rational_quadratic_transform +from .attentions import Encoder + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append( + nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dilated and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d( + gin_channels, 2 * hidden_channels * n_layers, 1 + ) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=p_dropout, + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d( + filter_channels, self.half_channels * (num_bins * 3 - 1), 1 + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( + self.filter_channels + ) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x + + +class TransformerCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + n_layers, + n_heads, + p_dropout=0, + filter_channels=0, + mean_only=False, + wn_sharing_parameter=None, + gin_channels=0, + ): + assert n_layers == 3, n_layers + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = ( + Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + isflow=True, + gin_channels=gin_channels, + ) + if wn_sharing_parameter is None + else wn_sharing_parameter + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x diff --git a/indextts/s2mel/modules/openvoice/openvoice_app.py b/indextts/s2mel/modules/openvoice/openvoice_app.py new file mode 100644 index 0000000000000000000000000000000000000000..744a7678cc2498987d90d639b63fd1c3a7bca5c7 --- /dev/null +++ b/indextts/s2mel/modules/openvoice/openvoice_app.py @@ -0,0 +1,275 @@ +import os +import torch +import argparse +import gradio as gr +from zipfile import ZipFile +import langid +from . import se_extractor +from .api import BaseSpeakerTTS, ToneColorConverter + +parser = argparse.ArgumentParser() +parser.add_argument("--share", action='store_true', default=False, help="make link public") +args = parser.parse_args() + +en_ckpt_base = 'checkpoints/base_speakers/EN' +zh_ckpt_base = 'checkpoints/base_speakers/ZH' +ckpt_converter = 'checkpoints/converter' +device = 'cuda' if torch.cuda.is_available() else 'cpu' +output_dir = 'outputs' +os.makedirs(output_dir, exist_ok=True) + +# load models +en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device) +en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth') +zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device) +zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth') +tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device) +tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth') + +# load speaker embeddings +en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device) +en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device) +zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device) + +# This online demo mainly supports English and Chinese +supported_languages = ['zh', 'en'] + +def predict(prompt, style, audio_file_pth, agree): + # initialize a empty info + text_hint = '' + # agree with the terms + if agree == False: + text_hint += '[ERROR] Please accept the Terms & Condition!\n' + gr.Warning("Please accept the Terms & Condition!") + return ( + text_hint, + None, + None, + ) + + # first detect the input language + language_predicted = langid.classify(prompt)[0].strip() + print(f"Detected language:{language_predicted}") + + if language_predicted not in supported_languages: + text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\n" + gr.Warning( + f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}" + ) + + return ( + text_hint, + None, + None, + ) + + if language_predicted == "zh": + tts_model = zh_base_speaker_tts + source_se = zh_source_se + language = 'Chinese' + if style not in ['default']: + text_hint += f"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\n" + gr.Warning(f"The style {style} is not supported for Chinese, which should be in ['default']") + return ( + text_hint, + None, + None, + ) + + else: + tts_model = en_base_speaker_tts + if style == 'default': + source_se = en_source_default_se + else: + source_se = en_source_style_se + language = 'English' + if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']: + text_hint += f"[ERROR] The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n" + gr.Warning(f"The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']") + return ( + text_hint, + None, + None, + ) + + speaker_wav = audio_file_pth + + if len(prompt) < 2: + text_hint += f"[ERROR] Please give a longer prompt text \n" + gr.Warning("Please give a longer prompt text") + return ( + text_hint, + None, + None, + ) + if len(prompt) > 200: + text_hint += f"[ERROR] Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo and try for your usage \n" + gr.Warning( + "Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo for your usage" + ) + return ( + text_hint, + None, + None, + ) + + # note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference + try: + target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir='processed', vad=True) + except Exception as e: + text_hint += f"[ERROR] Get target tone color error {str(e)} \n" + gr.Warning( + "[ERROR] Get target tone color error {str(e)} \n" + ) + return ( + text_hint, + None, + None, + ) + + src_path = f'{output_dir}/tmp.wav' + tts_model.tts(prompt, src_path, speaker=style, language=language) + + save_path = f'{output_dir}/output.wav' + # Run the tone color converter + encode_message = "@MyShell" + tone_color_converter.convert( + audio_src_path=src_path, + src_se=source_se, + tgt_se=target_se, + output_path=save_path, + message=encode_message) + + text_hint += f'''Get response successfully \n''' + + return ( + text_hint, + save_path, + speaker_wav, + ) + + + +title = "MyShell OpenVoice" + +description = """ +We introduce OpenVoice, a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set. +""" + +markdown_table = """ +

+ +| | | | +| :-----------: | :-----------: | :-----------: | +| **OpenSource Repo** | **Project Page** | **Join the Community** | +|
| [OpenVoice](https://research.myshell.ai/open-voice) | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) | + +
+""" + +markdown_table_v2 = """ +
+ +| | | | | +| :-----------: | :-----------: | :-----------: | :-----------: | +| **OpenSource Repo** |
| **Project Page** | [OpenVoice](https://research.myshell.ai/open-voice) | + +| | | +| :-----------: | :-----------: | +**Join the Community** | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) | + +
+""" +content = """ +
+ If the generated voice does not sound like the reference voice, please refer to this QnA. For multi-lingual & cross-lingual examples, please refer to this jupyter notebook. + This online demo mainly supports English. The default style also supports Chinese. But OpenVoice can adapt to any other language as long as a base speaker is provided. +
+""" +wrapped_markdown_content = f"
{content}
" + + +examples = [ + [ + "今天天气真好,我们一起出去吃饭吧。", + 'default', + "resources/demo_speaker1.mp3", + True, + ],[ + "This audio is generated by open voice with a half-performance model.", + 'whispering', + "resources/demo_speaker2.mp3", + True, + ], + [ + "He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.", + 'sad', + "resources/demo_speaker0.mp3", + True, + ], +] + +with gr.Blocks(analytics_enabled=False) as demo: + + with gr.Row(): + with gr.Column(): + with gr.Row(): + gr.Markdown( + """ + ## + """ + ) + with gr.Row(): + gr.Markdown(markdown_table_v2) + with gr.Row(): + gr.Markdown(description) + with gr.Column(): + gr.Video('https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f', autoplay=True) + + with gr.Row(): + gr.HTML(wrapped_markdown_content) + + with gr.Row(): + with gr.Column(): + input_text_gr = gr.Textbox( + label="Text Prompt", + info="One or two sentences at a time is better. Up to 200 text characters.", + value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.", + ) + style_gr = gr.Dropdown( + label="Style", + info="Select a style of output audio for the synthesised speech. (Chinese only support 'default' now)", + choices=['default', 'whispering', 'cheerful', 'terrified', 'angry', 'sad', 'friendly'], + max_choices=1, + value="default", + ) + ref_gr = gr.Audio( + label="Reference Audio", + info="Click on the ✎ button to upload your own target speaker audio", + type="filepath", + value="resources/demo_speaker2.mp3", + ) + tos_gr = gr.Checkbox( + label="Agree", + value=False, + info="I agree to the terms of the cc-by-nc-4.0 license-: https://github.com/myshell-ai/OpenVoice/blob/main/LICENSE", + ) + + tts_button = gr.Button("Send", elem_id="send-btn", visible=True) + + + with gr.Column(): + out_text_gr = gr.Text(label="Info") + audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True) + ref_audio_gr = gr.Audio(label="Reference Audio Used") + + gr.Examples(examples, + label="Examples", + inputs=[input_text_gr, style_gr, ref_gr, tos_gr], + outputs=[out_text_gr, audio_gr, ref_audio_gr], + fn=predict, + cache_examples=False,) + tts_button.click(predict, [input_text_gr, style_gr, ref_gr, tos_gr], outputs=[out_text_gr, audio_gr, ref_audio_gr]) + +demo.queue() +demo.launch(debug=True, show_api=True, share=args.share) diff --git a/indextts/s2mel/modules/openvoice/se_extractor.py b/indextts/s2mel/modules/openvoice/se_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..d08717984508c893dd88eeccc02b325ca4ea0a6e --- /dev/null +++ b/indextts/s2mel/modules/openvoice/se_extractor.py @@ -0,0 +1,153 @@ +import os +import glob +import torch +import hashlib +import librosa +import base64 +from glob import glob +import numpy as np +from pydub import AudioSegment +from faster_whisper import WhisperModel +import hashlib +import base64 +import librosa +# from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments + +model_size = "medium" +# Run on GPU with FP16 +model = None +def split_audio_whisper(audio_path, audio_name, target_dir='processed'): + global model + if model is None: + model = WhisperModel(model_size, device="cuda", compute_type="float16") + audio = AudioSegment.from_file(audio_path) + max_len = len(audio) + + target_folder = os.path.join(target_dir, audio_name) + + segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True) + segments = list(segments) + + # create directory + os.makedirs(target_folder, exist_ok=True) + wavs_folder = os.path.join(target_folder, 'wavs') + os.makedirs(wavs_folder, exist_ok=True) + + # segments + s_ind = 0 + start_time = None + + for k, w in enumerate(segments): + # process with the time + if k == 0: + start_time = max(0, w.start) + + end_time = w.end + + # calculate confidence + if len(w.words) > 0: + confidence = sum([s.probability for s in w.words]) / len(w.words) + else: + confidence = 0. + # clean text + text = w.text.replace('...', '') + + # left 0.08s for each audios + audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)] + + # segment file name + fname = f"{audio_name}_seg{s_ind}.wav" + + # filter out the segment shorter than 1.5s and longer than 20s + save = audio_seg.duration_seconds > 1.5 and \ + audio_seg.duration_seconds < 20. and \ + len(text) >= 2 and len(text) < 200 + + if save: + output_file = os.path.join(wavs_folder, fname) + audio_seg.export(output_file, format='wav') + + if k < len(segments) - 1: + start_time = max(0, segments[k+1].start - 0.08) + + s_ind = s_ind + 1 + return wavs_folder + + +def split_audio_vad(audio_path, audio_name, target_dir, split_seconds=10.0): + SAMPLE_RATE = 16000 + audio_vad = get_audio_tensor(audio_path) + segments = get_vad_segments( + audio_vad, + output_sample=True, + min_speech_duration=0.1, + min_silence_duration=1, + method="silero", + ) + segments = [(seg["start"], seg["end"]) for seg in segments] + segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments] + print(segments) + audio_active = AudioSegment.silent(duration=0) + audio = AudioSegment.from_file(audio_path) + + for start_time, end_time in segments: + audio_active += audio[int( start_time * 1000) : int(end_time * 1000)] + + audio_dur = audio_active.duration_seconds + print(f'after vad: dur = {audio_dur}') + target_folder = os.path.join(target_dir, audio_name) + wavs_folder = os.path.join(target_folder, 'wavs') + os.makedirs(wavs_folder, exist_ok=True) + start_time = 0. + count = 0 + num_splits = int(np.round(audio_dur / split_seconds)) + assert num_splits > 0, 'input audio is too short' + interval = audio_dur / num_splits + + for i in range(num_splits): + end_time = min(start_time + interval, audio_dur) + if i == num_splits - 1: + end_time = audio_dur + output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav" + audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)] + audio_seg.export(output_file, format='wav') + start_time = end_time + count += 1 + return wavs_folder + +def hash_numpy_array(audio_path): + array, _ = librosa.load(audio_path, sr=None, mono=True) + # Convert the array to bytes + array_bytes = array.tobytes() + # Calculate the hash of the array bytes + hash_object = hashlib.sha256(array_bytes) + hash_value = hash_object.digest() + # Convert the hash value to base64 + base64_value = base64.b64encode(hash_value) + return base64_value.decode('utf-8')[:16].replace('/', '_^') + +def get_se(audio_path, vc_model, target_dir='processed', vad=True): + device = vc_model.device + version = vc_model.version + print("OpenVoice version:", version) + + audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{version}_{hash_numpy_array(audio_path)}" + se_path = os.path.join(target_dir, audio_name, 'se.pth') + + # if os.path.isfile(se_path): + # se = torch.load(se_path).to(device) + # return se, audio_name + # if os.path.isdir(audio_path): + # wavs_folder = audio_path + + # if vad: + # wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name) + # else: + # wavs_folder = split_audio_whisper(audio_path, target_dir=target_dir, audio_name=audio_name) + + # audio_segs = glob(f'{wavs_folder}/*.wav') + # if len(audio_segs) == 0: + # raise NotImplementedError('No audio segments found!') + + return vc_model.extract_se([audio_path], se_save_path=se_path), audio_name + diff --git a/indextts/s2mel/modules/openvoice/transforms.py b/indextts/s2mel/modules/openvoice/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..a11f799e023864ff7082c1f49c0cc18351a13b47 --- /dev/null +++ b/indextts/s2mel/modules/openvoice/transforms.py @@ -0,0 +1,209 @@ +import torch +from torch.nn import functional as F + +import numpy as np + + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/indextts/s2mel/modules/openvoice/utils.py b/indextts/s2mel/modules/openvoice/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4e80909d6a03976400322cc0219d1871e9f84bfa --- /dev/null +++ b/indextts/s2mel/modules/openvoice/utils.py @@ -0,0 +1,194 @@ +import re +import json +import numpy as np + + +def get_hparams_from_file(config_path): + with open(config_path, "r", encoding="utf-8") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + +class HParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +def string_to_bits(string, pad_len=8): + # Convert each character to its ASCII value + ascii_values = [ord(char) for char in string] + + # Convert ASCII values to binary representation + binary_values = [bin(value)[2:].zfill(8) for value in ascii_values] + + # Convert binary strings to integer arrays + bit_arrays = [[int(bit) for bit in binary] for binary in binary_values] + + # Convert list of arrays to NumPy array + numpy_array = np.array(bit_arrays) + numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype) + numpy_array_full[:, 2] = 1 + max_len = min(pad_len, len(numpy_array)) + numpy_array_full[:max_len] = numpy_array[:max_len] + return numpy_array_full + + +def bits_to_string(bits_array): + # Convert each row of the array to a binary string + binary_values = [''.join(str(bit) for bit in row) for row in bits_array] + + # Convert binary strings to ASCII values + ascii_values = [int(binary, 2) for binary in binary_values] + + # Convert ASCII values to characters + output_string = ''.join(chr(value) for value in ascii_values) + + return output_string + + +def split_sentence(text, min_len=10, language_str='[EN]'): + if language_str in ['EN']: + sentences = split_sentences_latin(text, min_len=min_len) + else: + sentences = split_sentences_zh(text, min_len=min_len) + return sentences + +def split_sentences_latin(text, min_len=10): + """Split Long sentences into list of short ones + + Args: + str: Input sentences. + + Returns: + List[str]: list of output sentences. + """ + # deal with dirty sentences + text = re.sub('[。!?;]', '.', text) + text = re.sub('[,]', ',', text) + text = re.sub('[“”]', '"', text) + text = re.sub('[‘’]', "'", text) + text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text) + text = re.sub('[\n\t ]+', ' ', text) + text = re.sub('([,.!?;])', r'\1 $#!', text) + # split + sentences = [s.strip() for s in text.split('$#!')] + if len(sentences[-1]) == 0: del sentences[-1] + + new_sentences = [] + new_sent = [] + count_len = 0 + for ind, sent in enumerate(sentences): + # print(sent) + new_sent.append(sent) + count_len += len(sent.split(" ")) + if count_len > min_len or ind == len(sentences) - 1: + count_len = 0 + new_sentences.append(' '.join(new_sent)) + new_sent = [] + return merge_short_sentences_latin(new_sentences) + + +def merge_short_sentences_latin(sens): + """Avoid short sentences by merging them with the following sentence. + + Args: + List[str]: list of input sentences. + + Returns: + List[str]: list of output sentences. + """ + sens_out = [] + for s in sens: + # If the previous sentence is too short, merge them with + # the current sentence. + if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2: + sens_out[-1] = sens_out[-1] + " " + s + else: + sens_out.append(s) + try: + if len(sens_out[-1].split(" ")) <= 2: + sens_out[-2] = sens_out[-2] + " " + sens_out[-1] + sens_out.pop(-1) + except: + pass + return sens_out + +def split_sentences_zh(text, min_len=10): + text = re.sub('[。!?;]', '.', text) + text = re.sub('[,]', ',', text) + # 将文本中的换行符、空格和制表符替换为空格 + text = re.sub('[\n\t ]+', ' ', text) + # 在标点符号后添加一个空格 + text = re.sub('([,.!?;])', r'\1 $#!', text) + # 分隔句子并去除前后空格 + # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)] + sentences = [s.strip() for s in text.split('$#!')] + if len(sentences[-1]) == 0: del sentences[-1] + + new_sentences = [] + new_sent = [] + count_len = 0 + for ind, sent in enumerate(sentences): + new_sent.append(sent) + count_len += len(sent) + if count_len > min_len or ind == len(sentences) - 1: + count_len = 0 + new_sentences.append(' '.join(new_sent)) + new_sent = [] + return merge_short_sentences_zh(new_sentences) + + +def merge_short_sentences_zh(sens): + # return sens + """Avoid short sentences by merging them with the following sentence. + + Args: + List[str]: list of input sentences. + + Returns: + List[str]: list of output sentences. + """ + sens_out = [] + for s in sens: + # If the previous sentense is too short, merge them with + # the current sentence. + if len(sens_out) > 0 and len(sens_out[-1]) <= 2: + sens_out[-1] = sens_out[-1] + " " + s + else: + sens_out.append(s) + try: + if len(sens_out[-1]) <= 2: + sens_out[-2] = sens_out[-2] + " " + sens_out[-1] + sens_out.pop(-1) + except: + pass + return sens_out \ No newline at end of file diff --git a/indextts/s2mel/modules/quantize.py b/indextts/s2mel/modules/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..c81603b74d8e36ee9c70b7c123996d04a342ddda --- /dev/null +++ b/indextts/s2mel/modules/quantize.py @@ -0,0 +1,229 @@ +from dac.nn.quantize import ResidualVectorQuantize +from torch import nn +from modules.wavenet import WN +import torch +import torchaudio +import torchaudio.functional as audio_F +import numpy as np +from .alias_free_torch import * +from torch.nn.utils import weight_norm +from torch import nn, sin, pow +from einops.layers.torch import Rearrange +from dac.model.encodec import SConv1d + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False + ): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta := x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + return x + self.block(x) + +class CNNLSTM(nn.Module): + def __init__(self, indim, outdim, head, global_pred=False): + super().__init__() + self.global_pred = global_pred + self.model = nn.Sequential( + ResidualUnit(indim, dilation=1), + ResidualUnit(indim, dilation=2), + ResidualUnit(indim, dilation=3), + Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)), + Rearrange("b c t -> b t c"), + ) + self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)]) + + def forward(self, x): + # x: [B, C, T] + x = self.model(x) + if self.global_pred: + x = torch.mean(x, dim=1, keepdim=False) + outs = [head(x) for head in self.heads] + return outs + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) +class FAquantizer(nn.Module): + def __init__(self, in_dim=1024, + n_p_codebooks=1, + n_c_codebooks=2, + n_t_codebooks=2, + n_r_codebooks=3, + codebook_size=1024, + codebook_dim=8, + quantizer_dropout=0.5, + causal=False, + separate_prosody_encoder=False, + timbre_norm=False,): + super(FAquantizer, self).__init__() + conv1d_type = SConv1d# if causal else nn.Conv1d + self.prosody_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_p_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.content_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_c_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.residual_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_r_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.melspec_linear = conv1d_type(in_channels=20, out_channels=256, kernel_size=1, causal=causal) + self.melspec_encoder = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8, gin_channels=0, p_dropout=0.2, causal=causal) + self.melspec_linear2 = conv1d_type(in_channels=256, out_channels=1024, kernel_size=1, causal=causal) + + self.prob_random_mask_residual = 0.75 + + SPECT_PARAMS = { + "n_fft": 2048, + "win_length": 1200, + "hop_length": 300, + } + MEL_PARAMS = { + "n_mels": 80, + } + + self.to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS + ) + self.mel_mean, self.mel_std = -4, 4 + self.frame_rate = 24000 / 300 + self.hop_length = 300 + + def preprocess(self, wave_tensor, n_bins=20): + mel_tensor = self.to_mel(wave_tensor.squeeze(1)) + mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std + return mel_tensor[:, :n_bins, :int(wave_tensor.size(-1) / self.hop_length)] + + def forward(self, x, wave_segments): + outs = 0 + prosody_feature = self.preprocess(wave_segments) + + f0_input = prosody_feature # (B, T, 20) + f0_input = self.melspec_linear(f0_input) + f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to( + f0_input.device).bool()) + f0_input = self.melspec_linear2(f0_input) + + common_min_size = min(f0_input.size(2), x.size(2)) + f0_input = f0_input[:, :, :common_min_size] + + x = x[:, :, :common_min_size] + + z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer( + f0_input, 1 + ) + outs += z_p.detach() + + z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer( + x, 2 + ) + outs += z_c.detach() + + residual_feature = x - z_p.detach() - z_c.detach() + + z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer( + residual_feature, 3 + ) + + quantized = [z_p, z_c, z_r] + codes = [codes_p, codes_c, codes_r] + + return quantized, codes \ No newline at end of file diff --git a/indextts/s2mel/modules/rmvpe.py b/indextts/s2mel/modules/rmvpe.py new file mode 100644 index 0000000000000000000000000000000000000000..895c580a22753fdcb1e444520f0123dfb80d5826 --- /dev/null +++ b/indextts/s2mel/modules/rmvpe.py @@ -0,0 +1,631 @@ +from io import BytesIO +import os +from typing import List, Optional, Tuple +import numpy as np +import torch + +import torch.nn as nn +import torch.nn.functional as F +from librosa.util import normalize, pad_center, tiny +from scipy.signal import get_window + +import logging + +logger = logging.getLogger(__name__) + + +class STFT(torch.nn.Module): + def __init__( + self, filter_length=1024, hop_length=512, win_length=None, window="hann" + ): + """ + This module implements an STFT using 1D convolution and 1D transpose convolutions. + This is a bit tricky so there are some cases that probably won't work as working + out the same sizes before and after in all overlap add setups is tough. Right now, + this code should work with hop lengths that are half the filter length (50% overlap + between frames). + + Keyword Arguments: + filter_length {int} -- Length of filters used (default: {1024}) + hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512}) + win_length {[type]} -- Length of the window function applied to each frame (if not specified, it + equals the filter length). (default: {None}) + window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris) + (default: {'hann'}) + """ + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length if win_length else filter_length + self.window = window + self.forward_transform = None + self.pad_amount = int(self.filter_length / 2) + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + forward_basis = torch.FloatTensor(fourier_basis) + inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis)) + + assert filter_length >= self.win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, self.win_length, fftbins=True) + fft_window = pad_center(fft_window, size=filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis = (inverse_basis.T * fft_window).T + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + self.register_buffer("fft_window", fft_window.float()) + + def transform(self, input_data, return_phase=False): + """Take input data (audio) to STFT domain. + + Arguments: + input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples) + + Returns: + magnitude {tensor} -- Magnitude of STFT with shape (num_batch, + num_frequencies, num_frames) + phase {tensor} -- Phase of STFT with shape (num_batch, + num_frequencies, num_frames) + """ + input_data = F.pad( + input_data, + (self.pad_amount, self.pad_amount), + mode="reflect", + ) + forward_transform = input_data.unfold( + 1, self.filter_length, self.hop_length + ).permute(0, 2, 1) + forward_transform = torch.matmul(self.forward_basis, forward_transform) + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + magnitude = torch.sqrt(real_part**2 + imag_part**2) + if return_phase: + phase = torch.atan2(imag_part.data, real_part.data) + return magnitude, phase + else: + return magnitude + + def inverse(self, magnitude, phase): + """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced + by the ```transform``` function. + + Arguments: + magnitude {tensor} -- Magnitude of STFT with shape (num_batch, + num_frequencies, num_frames) + phase {tensor} -- Phase of STFT with shape (num_batch, + num_frequencies, num_frames) + + Returns: + inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of + shape (num_batch, num_samples) + """ + cat = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + fold = torch.nn.Fold( + output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length), + kernel_size=(1, self.filter_length), + stride=(1, self.hop_length), + ) + inverse_transform = torch.matmul(self.inverse_basis, cat) + inverse_transform = fold(inverse_transform)[ + :, 0, 0, self.pad_amount : -self.pad_amount + ] + window_square_sum = ( + self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0) + ) + window_square_sum = fold(window_square_sum)[ + :, 0, 0, self.pad_amount : -self.pad_amount + ] + inverse_transform /= window_square_sum + return inverse_transform + + def forward(self, input_data): + """Take input data (audio) to STFT domain and then back to audio. + + Arguments: + input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples) + + Returns: + reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of + shape (num_batch, num_samples) + """ + self.magnitude, self.phase = self.transform(input_data, return_phase=True) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +from time import time as ttime + + +class BiGRU(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiGRU, self).__init__() + self.gru = nn.GRU( + input_features, + hidden_features, + num_layers=num_layers, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.gru(x)[0] + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, momentum=0.01): + super(ConvBlockRes, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + # self.shortcut:Optional[nn.Module] = None + if in_channels != out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) + + def forward(self, x: torch.Tensor): + if not hasattr(self, "shortcut"): + return self.conv(x) + x + else: + return self.conv(x) + self.shortcut(x) + + +class Encoder(nn.Module): + def __init__( + self, + in_channels, + in_size, + n_encoders, + kernel_size, + n_blocks, + out_channels=16, + momentum=0.01, + ): + super(Encoder, self).__init__() + self.n_encoders = n_encoders + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + self.latent_channels = [] + for i in range(self.n_encoders): + self.layers.append( + ResEncoderBlock( + in_channels, out_channels, kernel_size, n_blocks, momentum=momentum + ) + ) + self.latent_channels.append([out_channels, in_size]) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x: torch.Tensor): + concat_tensors: List[torch.Tensor] = [] + x = self.bn(x) + for i, layer in enumerate(self.layers): + t, x = layer(x) + concat_tensors.append(t) + return x, concat_tensors + + +class ResEncoderBlock(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01 + ): + super(ResEncoderBlock, self).__init__() + self.n_blocks = n_blocks + self.conv = nn.ModuleList() + self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) + self.kernel_size = kernel_size + if self.kernel_size is not None: + self.pool = nn.AvgPool2d(kernel_size=kernel_size) + + def forward(self, x): + for i, conv in enumerate(self.conv): + x = conv(x) + if self.kernel_size is not None: + return x, self.pool(x) + else: + return x + + +class Intermediate(nn.Module): # + def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): + super(Intermediate, self).__init__() + self.n_inters = n_inters + self.layers = nn.ModuleList() + self.layers.append( + ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum) + ) + for i in range(self.n_inters - 1): + self.layers.append( + ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = layer(x) + return x + + +class ResDecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): + super(ResDecoderBlock, self).__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.n_blocks = n_blocks + self.conv1 = nn.Sequential( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=stride, + padding=(1, 1), + output_padding=out_padding, + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.conv2 = nn.ModuleList() + self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) + + def forward(self, x, concat_tensor): + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for i, conv2 in enumerate(self.conv2): + x = conv2(x) + return x + + +class Decoder(nn.Module): + def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): + super(Decoder, self).__init__() + self.layers = nn.ModuleList() + self.n_decoders = n_decoders + for i in range(self.n_decoders): + out_channels = in_channels // 2 + self.layers.append( + ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum) + ) + in_channels = out_channels + + def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]): + for i, layer in enumerate(self.layers): + x = layer(x, concat_tensors[-1 - i]) + return x + + +class DeepUnet(nn.Module): + def __init__( + self, + kernel_size, + n_blocks, + en_de_layers=5, + inter_layers=4, + in_channels=1, + en_out_channels=16, + ): + super(DeepUnet, self).__init__() + self.encoder = Encoder( + in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels + ) + self.intermediate = Intermediate( + self.encoder.out_channel // 2, + self.encoder.out_channel, + inter_layers, + n_blocks, + ) + self.decoder = Decoder( + self.encoder.out_channel, en_de_layers, kernel_size, n_blocks + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + x = self.decoder(x, concat_tensors) + return x + + +class E2E(nn.Module): + def __init__( + self, + n_blocks, + n_gru, + kernel_size, + en_de_layers=5, + inter_layers=4, + in_channels=1, + en_out_channels=16, + ): + super(E2E, self).__init__() + self.unet = DeepUnet( + kernel_size, + n_blocks, + en_de_layers, + inter_layers, + in_channels, + en_out_channels, + ) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * 128, 256, n_gru), + nn.Linear(512, 360), + nn.Dropout(0.25), + nn.Sigmoid(), + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid() + ) + + def forward(self, mel): + # print(mel.shape) + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + x = self.fc(x) + # print(x.shape) + return x + + +from librosa.filters import mel + + +class MelSpectrogram(torch.nn.Module): + def __init__( + self, + is_half, + n_mel_channels, + sampling_rate, + win_length, + hop_length, + n_fft=None, + mel_fmin=0, + mel_fmax=None, + clamp=1e-5, + ): + super().__init__() + n_fft = win_length if n_fft is None else n_fft + self.hann_window = {} + mel_basis = mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True, + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.n_fft = win_length if n_fft is None else n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.clamp = clamp + self.is_half = is_half + + def forward(self, audio, keyshift=0, speed=1, center=True): + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.n_fft * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + keyshift_key = str(keyshift) + "_" + str(audio.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to( + audio.device + ) + if "privateuseone" in str(audio.device): + if not hasattr(self, "stft"): + self.stft = STFT( + filter_length=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window="hann", + ).to(audio.device) + magnitude = self.stft.transform(audio) + else: + fft = torch.stft( + audio, + n_fft=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window=self.hann_window[keyshift_key], + center=center, + return_complex=True, + ) + magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + mel_output = torch.matmul(self.mel_basis, magnitude) + if self.is_half == True: + mel_output = mel_output.half() + log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) + return log_mel_spec + + +class RMVPE: + def __init__(self, model_path: str, is_half, device=None, use_jit=False): + self.resample_kernel = {} + self.resample_kernel = {} + self.is_half = is_half + if device is None: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.device = device + self.mel_extractor = MelSpectrogram( + is_half, 128, 16000, 1024, 160, None, 30, 8000 + ).to(device) + if "privateuseone" in str(device): + import onnxruntime as ort + + ort_session = ort.InferenceSession( + "%s/rmvpe.onnx" % os.environ["rmvpe_root"], + providers=["DmlExecutionProvider"], + ) + self.model = ort_session + else: + if str(self.device) == "cuda": + self.device = torch.device("cuda:0") + + def get_default_model(): + model = E2E(4, 1, (2, 2)) + ckpt = torch.load(model_path, map_location="cpu") + model.load_state_dict(ckpt) + model.eval() + if is_half: + model = model.half() + else: + model = model.float() + return model + + self.model = get_default_model() + + self.model = self.model.to(device) + cents_mapping = 20 * np.arange(360) + 1997.3794084376191 + self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368 + + def mel2hidden(self, mel): + with torch.no_grad(): + n_frames = mel.shape[-1] + n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames + if n_pad > 0: + mel = F.pad(mel, (0, n_pad), mode="constant") + if "privateuseone" in str(self.device): + onnx_input_name = self.model.get_inputs()[0].name + onnx_outputs_names = self.model.get_outputs()[0].name + hidden = self.model.run( + [onnx_outputs_names], + input_feed={onnx_input_name: mel.cpu().numpy()}, + )[0] + else: + mel = mel.half() if self.is_half else mel.float() + hidden = self.model(mel) + return hidden[:, :n_frames] + + def decode(self, hidden, thred=0.03): + cents_pred = self.to_local_average_cents(hidden, thred=thred) + f0 = 10 * (2 ** (cents_pred / 1200)) + f0[f0 == 10] = 0 + # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]) + return f0 + + def infer_from_audio(self, audio, thred=0.03): + # torch.cuda.synchronize() + # t0 = ttime() + if not torch.is_tensor(audio): + audio = torch.from_numpy(audio) + mel = self.mel_extractor( + audio.float().to(self.device).unsqueeze(0), center=True + ) + # print(123123123,mel.device.type) + # torch.cuda.synchronize() + # t1 = ttime() + hidden = self.mel2hidden(mel) + # torch.cuda.synchronize() + # t2 = ttime() + # print(234234,hidden.device.type) + if "privateuseone" not in str(self.device): + hidden = hidden.squeeze(0).cpu().numpy() + else: + hidden = hidden[0] + if self.is_half == True: + hidden = hidden.astype("float32") + + f0 = self.decode(hidden, thred=thred) + # torch.cuda.synchronize() + # t3 = ttime() + # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0)) + return f0 + def infer_from_audio_batch(self, audio, thred=0.03): + # torch.cuda.synchronize() + # t0 = ttime() + if not torch.is_tensor(audio): + audio = torch.from_numpy(audio) + mel = self.mel_extractor( + audio.float().to(self.device), center=True + ) + # print(123123123,mel.device.type) + # torch.cuda.synchronize() + # t1 = ttime() + hidden = self.mel2hidden(mel) + # torch.cuda.synchronize() + # t2 = ttime() + # print(234234,hidden.device.type) + if "privateuseone" not in str(self.device): + hidden = hidden.cpu().numpy() + else: + pass + if self.is_half == True: + hidden = hidden.astype("float32") + + f0s = [] + for bib in range(hidden.shape[0]): + f0s.append(self.decode(hidden[bib], thred=thred)) + f0s = np.stack(f0s) + f0s = torch.from_numpy(f0s).to(self.device) + # torch.cuda.synchronize() + # t3 = ttime() + # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0)) + return f0s + + def to_local_average_cents(self, salience, thred=0.05): + # t0 = ttime() + center = np.argmax(salience, axis=1) # 帧长#index + salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368 + # t1 = ttime() + center += 4 + todo_salience = [] + todo_cents_mapping = [] + starts = center - 4 + ends = center + 5 + for idx in range(salience.shape[0]): + todo_salience.append(salience[:, starts[idx] : ends[idx]][idx]) + todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]]) + # t2 = ttime() + todo_salience = np.array(todo_salience) # 帧长,9 + todo_cents_mapping = np.array(todo_cents_mapping) # 帧长,9 + product_sum = np.sum(todo_salience * todo_cents_mapping, 1) + weight_sum = np.sum(todo_salience, 1) # 帧长 + devided = product_sum / weight_sum # 帧长 + # t3 = ttime() + maxx = np.max(salience, axis=1) # 帧长 + devided[maxx <= thred] = 0 + # t4 = ttime() + # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) + return devided diff --git a/indextts/s2mel/modules/vocos/__init__.py b/indextts/s2mel/modules/vocos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a11ba951363004a596f88f8bf9c32190a68782c --- /dev/null +++ b/indextts/s2mel/modules/vocos/__init__.py @@ -0,0 +1,4 @@ +from .pretrained import Vocos + + +__version__ = "0.1.0" diff --git a/indextts/s2mel/modules/vocos/heads.py b/indextts/s2mel/modules/vocos/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..05917994ebe1c1679fcf2e1d74c8bbae2fbb58ca --- /dev/null +++ b/indextts/s2mel/modules/vocos/heads.py @@ -0,0 +1,164 @@ +from typing import Optional + +import torch +from torch import nn +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz + +from .spectral_ops import IMDCT, ISTFT +from .modules import symexp + + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + S = mag * (x + 1j * y) + audio = self.istft(S) + return audio + + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + sample_rate: Optional[int] = None, + clip_audio: bool = False, + ): + super().__init__() + out_dim = mdct_frame_len // 2 + self.out = nn.Linear(dim, out_dim) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + self.clip_audio = clip_audio + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.out.weight.mul_(scale.view(-1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + x = symexp(x) + x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + + return audio + + +class IMDCTCosHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False): + super().__init__() + self.clip_audio = clip_audio + self.out = nn.Linear(dim, mdct_frame_len) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTCosHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + m, p = x.chunk(2, dim=2) + m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes + audio = self.imdct(m * torch.cos(p)) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + return audio diff --git a/indextts/s2mel/modules/vocos/helpers.py b/indextts/s2mel/modules/vocos/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..3d303010352ad59dde2996605f124128ee17db36 --- /dev/null +++ b/indextts/s2mel/modules/vocos/helpers.py @@ -0,0 +1,71 @@ +import matplotlib +import numpy as np +import torch +from matplotlib import pyplot as plt +from pytorch_lightning import Callback + +matplotlib.use("Agg") + + +def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray: + """ + Save a matplotlib figure to a numpy array. + + Args: + fig (Figure): Matplotlib figure object. + + Returns: + ndarray: Numpy array representing the figure. + """ + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray: + """ + Plot a spectrogram and convert it to a numpy array. + + Args: + spectrogram (ndarray): Spectrogram data. + + Returns: + ndarray: Numpy array representing the plotted spectrogram. + """ + spectrogram = spectrogram.astype(np.float32) + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +class GradNormCallback(Callback): + """ + Callback to log the gradient norm. + """ + + def on_after_backward(self, trainer, model): + model.log("grad_norm", gradient_norm(model)) + + +def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor: + """ + Compute the gradient norm. + + Args: + model (Module): PyTorch model. + norm_type (float, optional): Type of the norm. Defaults to 2.0. + + Returns: + Tensor: Gradient norm. + """ + grads = [p.grad for p in model.parameters() if p.grad is not None] + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type) + return total_norm diff --git a/indextts/s2mel/modules/vocos/loss.py b/indextts/s2mel/modules/vocos/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..029f6ac0407d767b94abafa6fdf013755daabf44 --- /dev/null +++ b/indextts/s2mel/modules/vocos/loss.py @@ -0,0 +1,114 @@ +from typing import List, Tuple + +import torch +import torchaudio +from torch import nn + +from vocos.modules import safe_log + + +class MelSpecReconstructionLoss(nn.Module): + """ + L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample + """ + + def __init__( + self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100, + ): + super().__init__() + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1, + ) + + def forward(self, y_hat, y) -> torch.Tensor: + """ + Args: + y_hat (Tensor): Predicted audio waveform. + y (Tensor): Ground truth audio waveform. + + Returns: + Tensor: L1 loss between the mel-scaled magnitude spectrograms. + """ + mel_hat = safe_log(self.mel_spec(y_hat)) + mel = safe_log(self.mel_spec(y)) + + loss = torch.nn.functional.l1_loss(mel, mel_hat) + + return loss + + +class GeneratorLoss(nn.Module): + """ + Generator Loss module. Calculates the loss for the generator based on discriminator outputs. + """ + + def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + disc_outputs (List[Tensor]): List of discriminator outputs. + + Returns: + Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from + the sub-discriminators + """ + loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype) + gen_losses = [] + for dg in disc_outputs: + l = torch.mean(torch.clamp(1 - dg, min=0)) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class DiscriminatorLoss(nn.Module): + """ + Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. + """ + + def forward( + self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Args: + disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. + disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. + + Returns: + Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from + the sub-discriminators for real outputs, and a list of + loss values for generated outputs. + """ + loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype) + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(torch.clamp(1 - dr, min=0)) + g_loss = torch.mean(torch.clamp(1 + dg, min=0)) + loss += r_loss + g_loss + r_losses.append(r_loss) + g_losses.append(g_loss) + + return loss, r_losses, g_losses + + +class FeatureMatchingLoss(nn.Module): + """ + Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. + """ + + def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: + """ + Args: + fmap_r (List[List[Tensor]]): List of feature maps from real samples. + fmap_g (List[List[Tensor]]): List of feature maps from generated samples. + + Returns: + Tensor: The calculated feature matching loss. + """ + loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype) + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss diff --git a/indextts/s2mel/modules/vocos/models.py b/indextts/s2mel/modules/vocos/models.py new file mode 100644 index 0000000000000000000000000000000000000000..a9532752c9752851df0206fd0d7bb172ef92505b --- /dev/null +++ b/indextts/s2mel/modules/vocos/models.py @@ -0,0 +1,118 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from .modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. + """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + bandwidth_id = kwargs.get('bandwidth_id', None) + x = self.embed(x) + if self.adanorm: + assert bandwidth_id is not None + x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, cond_embedding_id=bandwidth_id) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. + """ + + def __init__( + self, input_channels, dim, num_blocks, layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x diff --git a/indextts/s2mel/modules/vocos/modules.py b/indextts/s2mel/modules/vocos/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..af1d6db16e2f10cc9af7bcc64434e2c983d756b7 --- /dev/null +++ b/indextts/s2mel/modules/vocos/modules.py @@ -0,0 +1,213 @@ +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn.utils import weight_norm, remove_weight_norm + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) diff --git a/indextts/s2mel/modules/vocos/pretrained.py b/indextts/s2mel/modules/vocos/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..dd8e494c35ac62546594f74ccc2ab623ad92f55c --- /dev/null +++ b/indextts/s2mel/modules/vocos/pretrained.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any, Dict, Tuple, Union, Optional + +import torch +import yaml +from torch import nn +from .heads import ISTFTHead +from .models import VocosBackbone + + +class Vocos(nn.Module): + """ + The Vocos class represents a Fourier-based neural vocoder for audio synthesis. + This class is primarily designed for inference, with support for loading from pretrained + model checkpoints. It consists of three main components: a feature extractor, + a backbone, and a head. + """ + + def __init__( + self, args, + ): + super().__init__() + self.backbone = VocosBackbone( + input_channels=args.vocos.backbone.input_channels, + dim=args.vocos.backbone.dim, + intermediate_dim=args.vocos.backbone.intermediate_dim, + num_layers=args.vocos.backbone.num_layers, + ) + self.head = ISTFTHead( + dim=args.vocos.head.dim, + n_fft=args.vocos.head.n_fft, + hop_length=args.vocos.head.hop_length, + padding=args.vocos.head.padding, + ) + + def forward(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: + """ + Method to decode audio waveform from already calculated features. The features input is passed through + the backbone and the head to reconstruct the audio output. + + Args: + features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size, + C denotes the feature dimension, and L is the sequence length. + + Returns: + Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). + """ + x = self.backbone(features_input, **kwargs) + audio_output = self.head(x) + return audio_output diff --git a/indextts/s2mel/modules/vocos/spectral_ops.py b/indextts/s2mel/modules/vocos/spectral_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a8eda1c8e18a32406aad40415f6a8bf60eb15fea --- /dev/null +++ b/indextts/s2mel/modules/vocos/spectral_ops.py @@ -0,0 +1,192 @@ +import numpy as np +import scipy +import torch +from torch import nn, view_as_real, view_as_complex + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + + +class MDCT(nn.Module): + """ + Modified Discrete Cosine Transform (MDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) + post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) + # view_as_real: NCCL Backend does not support ComplexFloat data type + # https://github.com/pytorch/pytorch/issues/71613 + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + """ + Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. + + Args: + audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size + and T is the length of the audio. + + Returns: + Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames + and N is the number of frequency bins. + """ + if self.padding == "center": + audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2)) + elif self.padding == "same": + # hop_length is 1/2 frame_len + audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4)) + else: + raise ValueError("Padding must be 'center' or 'same'.") + + x = audio.unfold(-1, self.frame_len, self.frame_len // 2) + N = self.frame_len // 2 + x = x * self.window.expand(x.shape) + X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N] + res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) + return torch.real(res) * np.sqrt(2) + + +class IMDCT(nn.Module): + """ + Inverse Modified Discrete Cosine Transform (IMDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. + """ + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1) + y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio diff --git a/indextts/s2mel/modules/wavenet.py b/indextts/s2mel/modules/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..3989be13bd0097410037a9685932e851584bc376 --- /dev/null +++ b/indextts/s2mel/modules/wavenet.py @@ -0,0 +1,174 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from indextts.s2mel.modules.encodec import SConv1d + +from . import commons +LRELU_SLOPE = 0.1 + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, + groups=channels, dilation=dilation, padding=padding + )) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, causal=False): + super(WN, self).__init__() + conv1d_type = SConv1d + assert (kernel_size % 2 == 1) + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size, + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + self.cond_layer = conv1d_type(gin_channels, 2 * hidden_channels * n_layers, 1, norm='weight_norm') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = conv1d_type(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, + padding=padding, norm='weight_norm', causal=causal) + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = conv1d_type(hidden_channels, res_skip_channels, 1, norm='weight_norm', causal=causal) + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply( + x_in, + g_l, + n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, :self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels:, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) diff --git a/indextts/s2mel/optimizers.py b/indextts/s2mel/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..979ed1b35009a675de966a267434bd15601e64c0 --- /dev/null +++ b/indextts/s2mel/optimizers.py @@ -0,0 +1,96 @@ +#coding:utf-8 +import os, sys +import os.path as osp +import numpy as np +import torch +from torch import nn +from torch.optim import Optimizer +from functools import reduce +from torch.optim import AdamW + +class MultiOptimizer: + def __init__(self, optimizers={}, schedulers={}): + self.optimizers = optimizers + self.schedulers = schedulers + self.keys = list(optimizers.keys()) + self.param_groups = reduce(lambda x,y: x+y, [v.param_groups for v in self.optimizers.values()]) + + def state_dict(self): + state_dicts = [(key, self.optimizers[key].state_dict())\ + for key in self.keys] + return state_dicts + + def scheduler_state_dict(self): + state_dicts = [(key, self.schedulers[key].state_dict())\ + for key in self.keys] + return state_dicts + + def load_state_dict(self, state_dict): + for key, val in state_dict: + try: + self.optimizers[key].load_state_dict(val) + except: + print("Unloaded %s" % key) + + def load_scheduler_state_dict(self, state_dict): + for key, val in state_dict: + try: + self.schedulers[key].load_state_dict(val) + except: + print("Unloaded %s" % key) + + def step(self, key=None, scaler=None): + keys = [key] if key is not None else self.keys + _ = [self._step(key, scaler) for key in keys] + + def _step(self, key, scaler=None): + if scaler is not None: + scaler.step(self.optimizers[key]) + scaler.update() + else: + self.optimizers[key].step() + + def zero_grad(self, key=None): + if key is not None: + self.optimizers[key].zero_grad() + else: + _ = [self.optimizers[key].zero_grad() for key in self.keys] + + def scheduler(self, *args, key=None): + if key is not None: + self.schedulers[key].step(*args) + else: + _ = [self.schedulers[key].step_batch(*args) for key in self.keys] + +def define_scheduler(optimizer, params): + scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params['gamma']) + + return scheduler + +def build_optimizer(model_dict, lr, type='AdamW'): + optim = {} + for key, model in model_dict.items(): + model_parameters = model.parameters() + parameters_names = [] + parameters_names.append( + [ + name_param_pair[0] + for name_param_pair in model.named_parameters() + ] + ) + if type == 'AdamW': + optim[key] = AdamW( + model_parameters, + lr=lr, + betas=(0.9, 0.98), + eps=1e-9, + weight_decay=0.1, + ) + else: + raise ValueError('Unknown optimizer type: %s' % type) + + schedulers = dict([(key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996)) + for key, opt in optim.items()]) + + multi_optim = MultiOptimizer(optim, schedulers) + return multi_optim \ No newline at end of file diff --git a/indextts/s2mel/wav2vecbert_extract.py b/indextts/s2mel/wav2vecbert_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..2aa845d68b606562c692f138602daf761cad33ef --- /dev/null +++ b/indextts/s2mel/wav2vecbert_extract.py @@ -0,0 +1,148 @@ +from transformers import SeamlessM4TFeatureExtractor +from transformers import Wav2Vec2BertModel +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import librosa +import os +import pickle +import math +import json +import safetensors +import json5 +# from codec.kmeans.repcodec_model import RepCodec +from startts.examples.ftchar.models.codec.kmeans.repcodec_model import RepCodec + +class JsonHParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = JsonHParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +def _load_config(config_fn, lowercase=False): + """Load configurations into a dictionary + + Args: + config_fn (str): path to configuration file + lowercase (bool, optional): whether changing keys to lower case. Defaults to False. + + Returns: + dict: dictionary that stores configurations + """ + with open(config_fn, "r") as f: + data = f.read() + config_ = json5.loads(data) + if "base_config" in config_: + # load configurations from new path + p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"]) + p_config_ = _load_config(p_config_path) + config_ = override_config(p_config_, config_) + if lowercase: + # change keys in config_ to lower case + config_ = get_lowercase_keys_config(config_) + return config_ + + +def load_config(config_fn, lowercase=False): + """Load configurations into a dictionary + + Args: + config_fn (str): path to configuration file + lowercase (bool, optional): _description_. Defaults to False. + + Returns: + JsonHParams: an object that stores configurations + """ + config_ = _load_config(config_fn, lowercase=lowercase) + # create an JsonHParams object with configuration dict + cfg = JsonHParams(**config_) + return cfg + +class Extract_wav2vectbert: + def __init__(self,device): + #semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0") + self.semantic_model = Wav2Vec2BertModel.from_pretrained("./MaskGCT_model/w2v_bert/") + self.semantic_model.eval() + self.semantic_model.to(device) + self.stat_mean_var = torch.load("./MaskGCT_model/wav2vec2bert_stats.pt") + self.semantic_mean = self.stat_mean_var["mean"] + self.semantic_std = torch.sqrt(self.stat_mean_var["var"]) + self.semantic_mean = self.semantic_mean.to(device) + self.semantic_std = self.semantic_std.to(device) + self.processor = SeamlessM4TFeatureExtractor.from_pretrained( + "./MaskGCT_model/w2v_bert/") + self.device = device + + cfg_maskgct = load_config('./MaskGCT_model/maskgct.json') + cfg = cfg_maskgct.model.semantic_codec + self.semantic_code_ckpt = r'./MaskGCT_model/semantic_codec/model.safetensors' + self.semantic_codec = RepCodec(cfg=cfg) + self.semantic_codec.eval() + self.semantic_codec.to(device) + safetensors.torch.load_model(self.semantic_codec, self.semantic_code_ckpt) + + @torch.no_grad() + def extract_features(self, speech): # speech [b,T] + inputs = self.processor(speech, sampling_rate=16000, return_tensors="pt") + input_features = inputs["input_features"] + attention_mask = inputs["attention_mask"] + return input_features, attention_mask #[2, 620, 160] [2, 620] + + @torch.no_grad() + def extract_semantic_code(self, input_features, attention_mask): + vq_emb = self.semantic_model( # Wav2Vec2BertModel + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + ) + feat = vq_emb.hidden_states[17] # (B, T, C) + feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat) + + semantic_code, rec_feat = self.semantic_codec.quantize(feat) # (B, T) + return semantic_code, rec_feat + + def feature_extract(self, prompt_speech): + + input_features, attention_mask = self.extract_features(prompt_speech) + input_features = input_features.to(self.device) + attention_mask = attention_mask.to(self.device) + semantic_code, rec_feat = self.extract_semantic_code(input_features, attention_mask) + return semantic_code,rec_feat + +if __name__=='__main__': + speech_path = 'test/magi1.wav' + speech = librosa.load(speech_path, sr=16000)[0] + speech = np.c_[speech,speech,speech].T #[2, 198559] + print(speech.shape) + + Extract_feature = Extract_wav2vectbert('cuda:0') + semantic_code,rec_feat = Extract_feature.feature_extract(speech) + print(semantic_code.shape,rec_feat.shape) + diff --git a/indextts/utils/__init__.py b/indextts/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/utils/arch_util.py b/indextts/utils/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..01eac108e8fdec0cc34640eacc61689ccf571d61 --- /dev/null +++ b/indextts/utils/arch_util.py @@ -0,0 +1,120 @@ +import math + +import torch +import torch.nn as nn + +from indextts.utils.xtransformers import RelativePositionBias + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + groups = 32 + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + assert groups > 2 + return GroupNorm32(groups, channels) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, mask=None, rel_pos=None): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + if rel_pos is not None: + weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1]) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + if mask is not None: + # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. + mask = mask.repeat(self.n_heads, 1).unsqueeze(1) + weight = weight * mask + a = torch.einsum("bts,bcs->bct", weight, v) + + return a.reshape(bs, -1, length) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + do_checkpoint=True, + relative_pos_embeddings=False, + ): + super().__init__() + self.channels = channels + self.do_checkpoint = do_checkpoint + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = normalization(channels) + self.qkv = nn.Conv1d(channels, channels * 3, 1) + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) + if relative_pos_embeddings: + self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) + else: + self.relative_pos_embeddings = None + + def forward(self, x, mask=None): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv, mask, self.relative_pos_embeddings) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) diff --git a/indextts/utils/checkpoint.py b/indextts/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e34d496cb1bf056bc1c4a3756e1e502ed927dc --- /dev/null +++ b/indextts/utils/checkpoint.py @@ -0,0 +1,34 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import logging +import os +import re +from collections import OrderedDict + +import torch +import yaml + + +def load_checkpoint(model: torch.nn.Module, model_pth: str) -> dict: + checkpoint = torch.load(model_pth, map_location='cpu') + checkpoint = checkpoint['model'] if 'model' in checkpoint else checkpoint + model.load_state_dict(checkpoint, strict=True) + info_path = re.sub('.pth$', '.yaml', model_pth) + configs = {} + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + return configs diff --git a/indextts/utils/common.py b/indextts/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..eb2f89a09b5ffb160b120058abed55839e1b03b5 --- /dev/null +++ b/indextts/utils/common.py @@ -0,0 +1,121 @@ +import os +import random +import re + +import torch +import torchaudio + +MATPLOTLIB_FLAG = False + + +def load_audio(audiopath, sampling_rate): + audio, sr = torchaudio.load(audiopath) + # print(f"wave shape: {audio.shape}, sample_rate: {sr}") + + if audio.size(0) > 1: # mix to mono + audio = audio[0].unsqueeze(0) + + if sr != sampling_rate: + try: + audio = torchaudio.functional.resample(audio, sr, sampling_rate) + except Exception as e: + print(f"Warning: {audiopath}, wave shape: {audio.shape}, sample_rate: {sr}") + return None + # clip audio invalid values + audio.clip_(-1, 1) + return audio + + +def tokenize_by_CJK_char(line: str, do_upper_case=True) -> str: + """ + Tokenize a line of text with CJK char. + + Note: All return charaters will be upper case. + + Example: + input = "你好世界是 hello world 的中文" + output = "你 好 世 界 是 HELLO WORLD 的 中 文" + + Args: + line: + The input text. + + Return: + A new string tokenize by CJK char. + """ + # The CJK ranges is from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py + CJK_RANGE_PATTERN = ( + r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])" + ) + chars = re.split(CJK_RANGE_PATTERN, line.strip()) + return " ".join([w.strip().upper() if do_upper_case else w.strip() for w in chars if w.strip()]) + + +def de_tokenized_by_CJK_char(line: str, do_lower_case=False) -> str: + """ + Example: + input = "你 好 世 界 是 HELLO WORLD 的 中 文" + output = "你好世界是 hello world 的中文" + + do_lower_case: + input = "SEE YOU!" + output = "see you!" + """ + # replace english words in the line with placeholders + english_word_pattern = re.compile(r"([A-Z]+(?:[\s-][A-Z-]+)*)", re.IGNORECASE) + english_sents = english_word_pattern.findall(line) + for i, sent in enumerate(english_sents): + line = line.replace(sent, f"") + + words = line.split() + # restore english sentences + sent_placeholder_pattern = re.compile(r"^.*?()") + for i in range(len(words)): + m = sent_placeholder_pattern.match(words[i]) + if m: + # restore the english word + placeholder_index = int(m.group(2)) + words[i] = words[i].replace(m.group(1), english_sents[placeholder_index]) + if do_lower_case: + words[i] = words[i].lower() + return "".join(words) + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) diff --git a/indextts/utils/feature_extractors.py b/indextts/utils/feature_extractors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3af3e01d1bbdb06fac465e570ac95335ccf0fe5 --- /dev/null +++ b/indextts/utils/feature_extractors.py @@ -0,0 +1,50 @@ +import torch +import torchaudio +from torch import nn +from indextts.utils.common import safe_log + + +class FeatureExtractor(nn.Module): + """Base class for feature extractors.""" + + def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Extract features from the given audio. + + Args: + audio (Tensor): Input audio waveform. + + Returns: + Tensor: Extracted features of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class MelSpectrogramFeatures(FeatureExtractor): + def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, win_length=None, + n_mels=100, mel_fmin=0, mel_fmax=None, normalize=False, padding="center"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + power=1, + normalized=normalize, + f_min=mel_fmin, + f_max=mel_fmax, + n_mels=n_mels, + center=padding == "center", + ) + + def forward(self, audio, **kwargs): + if self.padding == "same": + pad = self.mel_spec.win_length - self.mel_spec.hop_length + audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect") + mel = self.mel_spec(audio) + mel = safe_log(mel) + return mel diff --git a/indextts/utils/front.py b/indextts/utils/front.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c2e44a1e50a4b672bed79d7b46063842d4f09d --- /dev/null +++ b/indextts/utils/front.py @@ -0,0 +1,536 @@ +# -*- coding: utf-8 -*- +import os +import traceback +import re +from typing import List, Union, overload +import warnings +from indextts.utils.common import tokenize_by_CJK_char, de_tokenized_by_CJK_char +from sentencepiece import SentencePieceProcessor + + +class TextNormalizer: + def __init__(self): + self.zh_normalizer = None + self.en_normalizer = None + self.char_rep_map = { + ":": ",", + ";": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": " ", + "·": "-", + "、": ",", + "...": "…", + ",,,": "…", + ",,,": "…", + "……": "…", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "~": "-", + "~": "-", + "「": "'", + "」": "'", + ":": ",", + } + self.zh_char_rep_map = { + "$": ".", + **self.char_rep_map, + } + + def match_email(self, email): + # 正则表达式匹配邮箱格式:数字英文@数字英文.英文 + pattern = r"^[a-zA-Z0-9]+@[a-zA-Z0-9]+\.[a-zA-Z]+$" + return re.match(pattern, email) is not None + + PINYIN_TONE_PATTERN = r"(? str: + if not self.zh_normalizer or not self.en_normalizer: + print("Error, text normalizer is not initialized !!!") + return "" + if self.use_chinese(text): + text = re.sub(TextNormalizer.ENGLISH_CONTRACTION_PATTERN, r"\1 is", text, flags=re.IGNORECASE) + replaced_text, pinyin_list = self.save_pinyin_tones(text.rstrip()) + + replaced_text, original_name_list = self.save_names(replaced_text) + try: + result = self.zh_normalizer.normalize(replaced_text) + except Exception: + result = "" + print(traceback.format_exc()) + # 恢复人名 + result = self.restore_names(result, original_name_list) + # 恢复拼音声调 + result = self.restore_pinyin_tones(result, pinyin_list) + pattern = re.compile("|".join(re.escape(p) for p in self.zh_char_rep_map.keys())) + result = pattern.sub(lambda x: self.zh_char_rep_map[x.group()], result) + else: + try: + text = re.sub(TextNormalizer.ENGLISH_CONTRACTION_PATTERN, r"\1 is", text, flags=re.IGNORECASE) + result = self.en_normalizer.normalize(text) + except Exception: + result = text + print(traceback.format_exc()) + pattern = re.compile("|".join(re.escape(p) for p in self.char_rep_map.keys())) + result = pattern.sub(lambda x: self.char_rep_map[x.group()], result) + return result + + def correct_pinyin(self, pinyin: str): + """ + 将 jqx 的韵母为 u/ü 的拼音转换为 v + 如:ju -> jv , que -> qve, xün -> xvn + """ + if pinyin[0] not in "jqxJQX": + return pinyin + # 匹配 jqx 的韵母为 u/ü 的拼音 + pattern = r"([jqx])[uü](n|e|an)*(\d)" + repl = r"\g<1>v\g<2>\g<3>" + pinyin = re.sub(pattern, repl, pinyin, flags=re.IGNORECASE) + return pinyin.upper() + + def save_names(self, original_text): + """ + 替换人名为占位符 , ... + 例如:克里斯托弗·诺兰 -> + """ + # 人名 + name_pattern = re.compile(TextNormalizer.NAME_PATTERN, re.IGNORECASE) + original_name_list = re.findall(name_pattern, original_text) + if len(original_name_list) == 0: + return (original_text, None) + original_name_list = list(set("".join(n) for n in original_name_list)) + transformed_text = original_text + # 替换占位符 , ... + for i, name in enumerate(original_name_list): + number = chr(ord("a") + i) + transformed_text = transformed_text.replace(name, f"") + + return transformed_text, original_name_list + + def restore_names(self, normalized_text, original_name_list): + """ + 恢复人名为原来的文字 + 例如: -> original_name_list[0] + """ + if not original_name_list or len(original_name_list) == 0: + return normalized_text + + transformed_text = normalized_text + # 替换为占位符 , ... + for i, name in enumerate(original_name_list): + number = chr(ord("a") + i) + transformed_text = transformed_text.replace(f"", name) + return transformed_text + + def save_pinyin_tones(self, original_text): + """ + 替换拼音声调为占位符 , , ... + 例如:xuan4 -> + """ + # 声母韵母+声调数字 + origin_pinyin_pattern = re.compile(TextNormalizer.PINYIN_TONE_PATTERN, re.IGNORECASE) + original_pinyin_list = re.findall(origin_pinyin_pattern, original_text) + if len(original_pinyin_list) == 0: + return (original_text, None) + original_pinyin_list = list(set("".join(p) for p in original_pinyin_list)) + transformed_text = original_text + # 替换为占位符 , , ... + for i, pinyin in enumerate(original_pinyin_list): + number = chr(ord("a") + i) + transformed_text = transformed_text.replace(pinyin, f"") + + # print("original_text: ", original_text) + # print("transformed_text: ", transformed_text) + return transformed_text, original_pinyin_list + + def restore_pinyin_tones(self, normalized_text, original_pinyin_list): + """ + 恢复拼音中的音调数字(1-5)为原来的拼音 + 例如: -> original_pinyin_list[0] + """ + if not original_pinyin_list or len(original_pinyin_list) == 0: + return normalized_text + + transformed_text = normalized_text + # 替换占位符 , , ... + for i, pinyin in enumerate(original_pinyin_list): + number = chr(ord("a") + i) + pinyin = self.correct_pinyin(pinyin) + transformed_text = transformed_text.replace(f"", pinyin) + # print("normalized_text: ", normalized_text) + # print("transformed_text: ", transformed_text) + return transformed_text + + +class TextTokenizer: + def __init__(self, vocab_file: str, normalizer: TextNormalizer = None): + self.vocab_file = vocab_file + self.normalizer = normalizer + + if self.vocab_file is None: + raise ValueError("vocab_file is None") + if not os.path.exists(self.vocab_file): + raise ValueError(f"vocab_file {self.vocab_file} does not exist") + if self.normalizer: + self.normalizer.load() + # 加载词表 + self.sp_model = SentencePieceProcessor(model_file=self.vocab_file) + + self.pre_tokenizers = [ + # 预处理器 + tokenize_by_CJK_char, + ] + + @property + def vocab_size(self): + return self.sp_model.GetPieceSize() + + @property + def unk_token(self): + return "" + + @property + def pad_token(self): + return None + + @property + def bos_token(self): + return "" + + @property + def eos_token(self): + return "" + + @property + def pad_token_id(self): + return -1 + + @property + def bos_token_id(self): + return 0 + + @property + def eos_token_id(self): + return 1 + + @property + def unk_token_id(self): + return self.sp_model.unk_id() + + @property + def special_tokens_map(self): + return { + "unk_token": self.unk_token, + "pad_token": self.pad_token, + "bos_token": self.bos_token, + "eos_token": self.eos_token, + } + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + return vocab + + @overload + def convert_ids_to_tokens(self, ids: int) -> str: ... + + @overload + def convert_ids_to_tokens(self, ids: List[int]) -> List[str]: ... + + def convert_ids_to_tokens(self, ids: Union[List[int], int]): + return self.sp_model.IdToPiece(ids) + + def convert_tokens_to_ids(self, tokens: Union[List[str], str]) -> List[int]: + if isinstance(tokens, str): + tokens = [tokens] + return [self.sp_model.PieceToId(token) for token in tokens] + + def tokenize(self, text: str) -> List[str]: + return self.encode(text, out_type=str) + + def encode(self, text: str, **kwargs): + if len(text) == 0: + return [] + if len(text.strip()) == 1: + return self.sp_model.Encode(text, out_type=kwargs.pop("out_type", int), **kwargs) + # 预处理 + if self.normalizer: + text = self.normalizer.normalize(text) + if len(self.pre_tokenizers) > 0: + for pre_tokenizer in self.pre_tokenizers: + text = pre_tokenizer(text) + return self.sp_model.Encode(text, out_type=kwargs.pop("out_type", int), **kwargs) + + def batch_encode(self, texts: List[str], **kwargs): + # 预处理 + if self.normalizer: + texts = [self.normalizer.normalize(text) for text in texts] + if len(self.pre_tokenizers) > 0: + for pre_tokenizer in self.pre_tokenizers: + texts = [pre_tokenizer(text) for text in texts] + return self.sp_model.Encode(texts, out_type=kwargs.pop("out_type", int), **kwargs) + + def decode(self, ids: Union[List[int], int], do_lower_case=False, **kwargs): + if isinstance(ids, int): + ids = [ids] + decoded = self.sp_model.Decode(ids, out_type=kwargs.pop("out_type", str), **kwargs) + return de_tokenized_by_CJK_char(decoded, do_lower_case=do_lower_case) + + @staticmethod + def split_sentences_by_token( + tokenized_str: List[str], split_tokens: List[str], max_tokens_per_sentence: int + ) -> List[List[str]]: + """ + 将tokenize后的结果按特定token进一步分割 + """ + # 处理特殊情况 + if len(tokenized_str) == 0: + return [] + sentences: List[List[str]] = [] + current_sentence = [] + current_sentence_tokens_len = 0 + for i in range(len(tokenized_str)): + token = tokenized_str[i] + current_sentence.append(token) + current_sentence_tokens_len += 1 + if current_sentence_tokens_len <= max_tokens_per_sentence: + if token in split_tokens and current_sentence_tokens_len > 2: + if i < len(tokenized_str) - 1: + if tokenized_str[i + 1] in ["'", "▁'"]: + # 后续token是',则不切分 + current_sentence.append(tokenized_str[i + 1]) + i += 1 + sentences.append(current_sentence) + current_sentence = [] + current_sentence_tokens_len = 0 + continue + # 如果当前tokens的长度超过最大限制 + if not ("," in split_tokens or "▁," in split_tokens ) and ("," in current_sentence or "▁," in current_sentence): + # 如果当前tokens中有,,则按,分割 + sub_sentences = TextTokenizer.split_sentences_by_token( + current_sentence, [",", "▁,"], max_tokens_per_sentence=max_tokens_per_sentence + ) + elif "-" not in split_tokens and "-" in current_sentence: + # 没有,,则按-分割 + sub_sentences = TextTokenizer.split_sentences_by_token( + current_sentence, ["-"], max_tokens_per_sentence=max_tokens_per_sentence + ) + else: + # 按照长度分割 + sub_sentences = [] + for j in range(0, len(current_sentence), max_tokens_per_sentence): + if j + max_tokens_per_sentence < len(current_sentence): + sub_sentences.append(current_sentence[j : j + max_tokens_per_sentence]) + else: + sub_sentences.append(current_sentence[j:]) + warnings.warn( + f"The tokens length of sentence exceeds limit: {max_tokens_per_sentence}, " + f"Tokens in sentence: {current_sentence}." + "Maybe unexpected behavior", + RuntimeWarning, + ) + sentences.extend(sub_sentences) + current_sentence = [] + current_sentence_tokens_len = 0 + if current_sentence_tokens_len > 0: + assert current_sentence_tokens_len <= max_tokens_per_sentence + sentences.append(current_sentence) + # 如果相邻的句子加起来长度小于最大限制,则合并 + merged_sentences = [] + for sentence in sentences: + if len(sentence) == 0: + continue + if len(merged_sentences) == 0: + merged_sentences.append(sentence) + elif len(merged_sentences[-1]) + len(sentence) <= max_tokens_per_sentence: + merged_sentences[-1] = merged_sentences[-1] + sentence + else: + merged_sentences.append(sentence) + return merged_sentences + + punctuation_marks_tokens = [ + ".", + "!", + "?", + "▁.", + # "▁!", # unk + "▁?", + "▁...", # ellipsis + ] + def split_sentences(self, tokenized: List[str], max_tokens_per_sentence=120) -> List[List[str]]: + return TextTokenizer.split_sentences_by_token( + tokenized, self.punctuation_marks_tokens, max_tokens_per_sentence=max_tokens_per_sentence + ) + + +if __name__ == "__main__": + # 测试程序 + + text_normalizer = TextNormalizer() + + cases = [ + "IndexTTS 正式发布1.0版本了,效果666", + "晕XUAN4是一种GAN3觉", + "我爱你!", + "I love you!", + "“我爱你”的英语是“I love you”", + "2.5平方电线", + "共465篇,约315万字", + "2002年的第一场雪,下在了2003年", + "速度是10km/h", + "现在是北京时间2025年01月11日 20:00", + "他这条裤子是2012年买的,花了200块钱", + "电话:135-4567-8900", + "1键3连", + "他这条视频点赞3000+,评论1000+,收藏500+", + "这是1024元的手机,你要吗?", + "受不liao3你了", + "“衣裳”不读衣chang2,而是读衣shang5", + "最zhong4要的是:不要chong2蹈覆辙", + "不zuo1死就不会死", + "See you at 8:00 AM", + "8:00 AM 开会", + "Couting down 3, 2, 1, go!", + "数到3就开始:1、2、3", + "This sales for 2.5% off, only $12.5.", + "5G网络是4G网络的升级版,2G网络是3G网络的前身", + "苹果于2030/1/2发布新 iPhone 2X 系列手机,最低售价仅 ¥12999", + "这酒...里...有毒...", + # 异常case + "只有,,,才是最好的", + "babala2是什么?", # babala二是什么? + "用beta1测试", # 用beta一测试 + "have you ever been to beta2?", # have you ever been to beta two? + "such as XTTS, CosyVoice2, Fish-Speech, and F5-TTS", # such as xtts,cosyvoice two,fish-speech,and f five-tts + "where's the money?", # where is the money? + "who's there?", # who is there? + "which's the best?", # which is the best? + "how's it going?", # how is it going? + "今天是个好日子 it's a good day", # 今天是个好日子 it is a good day + # 人名 + "约瑟夫·高登-莱维特(Joseph Gordon-Levitt is an American actor)", + "蒂莫西·唐纳德·库克(英文名:Timothy Donald Cook),通称蒂姆·库克(Tim Cook),美国商业经理、工业工程师和工业开发商,现任苹果公司首席执行官。", + # 长句子 + "《盗梦空间》是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰执导并编剧,莱昂纳多·迪卡普里奥、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演,2010年7月16日在美国上映,2010年9月1日在中国内地上映,2020年8月28日在中国内地重映。影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”,讲述了由莱昂纳多·迪卡普里奥扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。", + "清晨拉开窗帘,阳光洒在窗台的Bloomixy花艺礼盒上——薰衣草香薰蜡烛唤醒嗅觉,永生花束折射出晨露般光泽。设计师将“自然绽放美学”融入每个细节:手工陶瓷花瓶可作首饰收纳,香薰精油含依兰依兰舒缓配方。限量款附赠《365天插花灵感手册》,让每个平凡日子都有花开仪式感。\n宴会厅灯光暗下的刹那,Glimmeria星月系列耳坠开始发光——瑞士冷珐琅工艺让蓝宝石如银河流动,钛合金骨架仅3.2g无负重感。设计师秘密:内置微型重力感应器,随步伐产生0.01mm振幅,打造“行走的星光”。七夕限定礼盒含星座定制铭牌,让爱意如星辰永恒闪耀。", + "电影1:“黑暗骑士”(演员:克里斯蒂安·贝尔、希斯·莱杰;导演:克里斯托弗·诺兰);电影2:“盗梦空间”(演员:莱昂纳多·迪卡普里奥;导演:克里斯托弗·诺兰);电影3:“钢琴家”(演员:艾德里安·布洛迪;导演:罗曼·波兰斯基);电影4:“泰坦尼克号”(演员:莱昂纳多·迪卡普里奥;导演:詹姆斯·卡梅隆);电影5:“阿凡达”(演员:萨姆·沃辛顿;导演:詹姆斯·卡梅隆);电影6:“南方公园:大电影”(演员:马特·斯通、托马斯·艾恩格瑞;导演:特雷·帕克)", + ] + # 测试分词器 + tokenizer = TextTokenizer( + vocab_file="checkpoints/bpe.model", + normalizer=text_normalizer, + ) + + codes = tokenizer.batch_encode( + cases, + out_type=int, + ) + + print(f"vocab_size: {tokenizer.vocab_size}") + # print(f"pad_token: {tokenizer.pad_token}, pad_token_id: {tokenizer.pad_token_id}") + print(f"bos_token: {tokenizer.bos_token}, bos_token_id: {tokenizer.bos_token_id}") + print(f"eos_token: {tokenizer.eos_token}, eos_token_id: {tokenizer.eos_token_id}") + print(f"unk_token: {tokenizer.unk_token}, unk_token_id: {tokenizer.unk_token_id}") + # 测试拼音 (8474-10201) + for id in range(8474, 10201): + pinyin = tokenizer.convert_ids_to_tokens(id) + if re.match(TextNormalizer.PINYIN_TONE_PATTERN, pinyin, re.IGNORECASE) is None: + print(f"{pinyin} should be matched") + for badcase in [ + "beta1", "better1", "voice2", "bala2", "babala2", "hunger2" + ]: + if re.match(TextNormalizer.PINYIN_TONE_PATTERN, badcase, re.IGNORECASE) is not None: + print(f"{badcase} should not be matched!") + # 不应该有 unk_token_id + for t in set([*TextTokenizer.punctuation_marks_tokens, ",", "▁,", "-", "▁..."]): + tokens = tokenizer.convert_tokens_to_ids(t) + if tokenizer.unk_token_id in tokens: + print(f"Warning: {t} is unknown token") + print(f"`{t}`", "->", tokens, "->", tokenizer.convert_ids_to_tokens(tokens)) + for ch in set(tokenizer.normalizer.zh_char_rep_map.values()): + # 测试 normalize后的字符能被分词器识别 + print(f"`{ch}`", "->", tokenizer.sp_model.Encode(ch, out_type=str)) + print(f"` {ch}`", "->", tokenizer.sp_model.Encode(f" {ch}", out_type=str)) + max_tokens_per_sentence=120 + for i in range(len(cases)): + print(f"原始文本: {cases[i]}") + print(f"Normalized: {text_normalizer.normalize(cases[i])}") + tokens = tokenizer.tokenize(cases[i]) + print("Tokenzied: ", ", ".join([f"`{t}`" for t in tokens])) + sentences = tokenizer.split_sentences(tokens, max_tokens_per_sentence=max_tokens_per_sentence) + print("Splitted sentences count:", len(sentences)) + if len(sentences) > 1: + for j in range(len(sentences)): + print(f" {j}, count:", len(sentences[j]), ", tokens:", "".join(sentences[j])) + if len(sentences[j]) > max_tokens_per_sentence: + print(f"Warning: sentence {j} is too long, length: {len(sentences[j])}") + #print(f"Token IDs (first 10): {codes[i][:10]}") + if tokenizer.unk_token in codes[i]: + print(f"Warning: `{cases[i]}` contains UNKNOWN token") + print(f"Decoded: {tokenizer.decode(codes[i], do_lower_case=True)}") + print("-" * 50) diff --git a/indextts/utils/maskgct/models/codec/__init__.py b/indextts/utils/maskgct/models/codec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/utils/maskgct/models/codec/amphion_codec/codec.py b/indextts/utils/maskgct/models/codec/amphion_codec/codec.py new file mode 100644 index 0000000000000000000000000000000000000000..65d62344750044b4a19ba4bfc2b52de9daa46dda --- /dev/null +++ b/indextts/utils/maskgct/models/codec/amphion_codec/codec.py @@ -0,0 +1,427 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from indextts.utils.maskgct.models.codec.amphion_codec.quantize import ( + ResidualVQ, + VectorQuantize, + FactorizedVectorQuantize, + LookupFreeQuantize, +) + +from indextts.utils.maskgct.models.codec.amphion_codec.vocos import Vocos + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class CodecEncoder(nn.Module): + def __init__( + self, + d_model: int = 64, + up_ratios: list = [4, 5, 5, 6], + out_channels: int = 256, + use_tanh: bool = False, + cfg=None, + ): + super().__init__() + + d_model = cfg.d_model if cfg is not None else d_model + up_ratios = cfg.up_ratios if cfg is not None else up_ratios + out_channels = cfg.out_channels if cfg is not None else out_channels + use_tanh = cfg.use_tanh if cfg is not None else use_tanh + + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in up_ratios: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, out_channels, kernel_size=3, padding=1), + ] + + if use_tanh: + self.block += [nn.Tanh()] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + self.reset_parameters() + + def forward(self, x): + return self.block(x) + + def reset_parameters(self): + self.apply(init_weights) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class CodecDecoder(nn.Module): + def __init__( + self, + in_channels: int = 256, + upsample_initial_channel: int = 1536, + up_ratios: list = [5, 5, 4, 2], + num_quantizers: int = 8, + codebook_size: int = 1024, + codebook_dim: int = 256, + quantizer_type: str = "vq", + quantizer_dropout: float = 0.5, + commitment: float = 0.25, + codebook_loss_weight: float = 1.0, + use_l2_normlize: bool = False, + codebook_type: str = "euclidean", + kmeans_init: bool = False, + kmeans_iters: int = 10, + decay: float = 0.8, + eps: float = 1e-5, + threshold_ema_dead_code: int = 2, + weight_init: bool = False, + use_vocos: bool = False, + vocos_dim: int = 384, + vocos_intermediate_dim: int = 1152, + vocos_num_layers: int = 8, + n_fft: int = 800, + hop_size: int = 200, + padding: str = "same", + cfg=None, + ): + super().__init__() + + in_channels = ( + cfg.in_channels + if cfg is not None and hasattr(cfg, "in_channels") + else in_channels + ) + upsample_initial_channel = ( + cfg.upsample_initial_channel + if cfg is not None and hasattr(cfg, "upsample_initial_channel") + else upsample_initial_channel + ) + up_ratios = ( + cfg.up_ratios + if cfg is not None and hasattr(cfg, "up_ratios") + else up_ratios + ) + num_quantizers = ( + cfg.num_quantizers + if cfg is not None and hasattr(cfg, "num_quantizers") + else num_quantizers + ) + codebook_size = ( + cfg.codebook_size + if cfg is not None and hasattr(cfg, "codebook_size") + else codebook_size + ) + codebook_dim = ( + cfg.codebook_dim + if cfg is not None and hasattr(cfg, "codebook_dim") + else codebook_dim + ) + quantizer_type = ( + cfg.quantizer_type + if cfg is not None and hasattr(cfg, "quantizer_type") + else quantizer_type + ) + quantizer_dropout = ( + cfg.quantizer_dropout + if cfg is not None and hasattr(cfg, "quantizer_dropout") + else quantizer_dropout + ) + commitment = ( + cfg.commitment + if cfg is not None and hasattr(cfg, "commitment") + else commitment + ) + codebook_loss_weight = ( + cfg.codebook_loss_weight + if cfg is not None and hasattr(cfg, "codebook_loss_weight") + else codebook_loss_weight + ) + use_l2_normlize = ( + cfg.use_l2_normlize + if cfg is not None and hasattr(cfg, "use_l2_normlize") + else use_l2_normlize + ) + codebook_type = ( + cfg.codebook_type + if cfg is not None and hasattr(cfg, "codebook_type") + else codebook_type + ) + kmeans_init = ( + cfg.kmeans_init + if cfg is not None and hasattr(cfg, "kmeans_init") + else kmeans_init + ) + kmeans_iters = ( + cfg.kmeans_iters + if cfg is not None and hasattr(cfg, "kmeans_iters") + else kmeans_iters + ) + decay = cfg.decay if cfg is not None and hasattr(cfg, "decay") else decay + eps = cfg.eps if cfg is not None and hasattr(cfg, "eps") else eps + threshold_ema_dead_code = ( + cfg.threshold_ema_dead_code + if cfg is not None and hasattr(cfg, "threshold_ema_dead_code") + else threshold_ema_dead_code + ) + weight_init = ( + cfg.weight_init + if cfg is not None and hasattr(cfg, "weight_init") + else weight_init + ) + use_vocos = ( + cfg.use_vocos + if cfg is not None and hasattr(cfg, "use_vocos") + else use_vocos + ) + vocos_dim = ( + cfg.vocos_dim + if cfg is not None and hasattr(cfg, "vocos_dim") + else vocos_dim + ) + vocos_intermediate_dim = ( + cfg.vocos_intermediate_dim + if cfg is not None and hasattr(cfg, "vocos_intermediate_dim") + else vocos_intermediate_dim + ) + vocos_num_layers = ( + cfg.vocos_num_layers + if cfg is not None and hasattr(cfg, "vocos_num_layers") + else vocos_num_layers + ) + n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft + hop_size = ( + cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size + ) + padding = ( + cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding + ) + + if quantizer_type == "vq": + self.quantizer = ResidualVQ( + input_dim=in_channels, + num_quantizers=num_quantizers, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_type=quantizer_type, + quantizer_dropout=quantizer_dropout, + commitment=commitment, + codebook_loss_weight=codebook_loss_weight, + use_l2_normlize=use_l2_normlize, + codebook_type=codebook_type, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + eps=eps, + threshold_ema_dead_code=threshold_ema_dead_code, + weight_init=weight_init, + ) + elif quantizer_type == "fvq": + self.quantizer = ResidualVQ( + input_dim=in_channels, + num_quantizers=num_quantizers, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_type=quantizer_type, + quantizer_dropout=quantizer_dropout, + commitment=commitment, + codebook_loss_weight=codebook_loss_weight, + use_l2_normlize=use_l2_normlize, + ) + elif quantizer_type == "lfq": + self.quantizer = ResidualVQ( + input_dim=in_channels, + num_quantizers=num_quantizers, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_type=quantizer_type, + ) + else: + raise ValueError(f"Unknown quantizer type {quantizer_type}") + + if not use_vocos: + # Add first conv layer + channels = upsample_initial_channel + layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(up_ratios): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, 1, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + if use_vocos: + self.model = Vocos( + input_channels=in_channels, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=vocos_num_layers, + adanorm_num_embeddings=None, + n_fft=n_fft, + hop_size=hop_size, + padding=padding, + ) + + self.reset_parameters() + + def forward(self, x=None, vq=False, eval_vq=False, n_quantizers=None): + """ + if vq is True, x = encoder output, then return quantized output; + else, x = quantized output, then return decoder output + """ + if vq is True: + if eval_vq: + self.quantizer.eval() + ( + quantized_out, + all_indices, + all_commit_losses, + all_codebook_losses, + all_quantized, + ) = self.quantizer(x, n_quantizers=n_quantizers) + return ( + quantized_out, + all_indices, + all_commit_losses, + all_codebook_losses, + all_quantized, + ) + + return self.model(x) + + def quantize(self, x, n_quantizers=None): + self.quantizer.eval() + quantized_out, vq, _, _, _ = self.quantizer(x, n_quantizers=n_quantizers) + return quantized_out, vq + + # TODO: check consistency of vq2emb and quantize + def vq2emb(self, vq, n_quantizers=None): + return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers) + + def decode(self, x): + return self.model(x) + + def latent2dist(self, x, n_quantizers=None): + return self.quantizer.latent2dist(x, n_quantizers=n_quantizers) + + def reset_parameters(self): + self.apply(init_weights) diff --git a/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2cf45a4b1b8ae0a2590cd37d0dd098895a339f --- /dev/null +++ b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from indextts.utils.maskgct.models.codec.amphion_codec.quantize.factorized_vector_quantize import ( + FactorizedVectorQuantize, +) +from indextts.utils.maskgct.models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize +from indextts.utils.maskgct.models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize +from indextts.utils.maskgct.models.codec.amphion_codec.quantize.residual_vq import ResidualVQ diff --git a/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..5c359f8cb60e8d2617a28f8d20806b2dbfd1b588 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class FactorizedVectorQuantize(nn.Module): + def __init__( + self, + input_dim, + codebook_size, + codebook_dim, + commitment=0.005, + codebook_loss_weight=1.0, + use_l2_normlize=True, + ): + super().__init__() + self.input_dim = input_dim + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + self.codebook_loss_weight = codebook_loss_weight + self.use_l2_normlize = use_l2_normlize + + if self.input_dim != self.codebook_dim: + self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) + self.out_project = WNConv1d( + self.codebook_dim, self.input_dim, kernel_size=1 + ) + + else: + self.in_project = nn.Identity() + self.out_project = nn.Identity() + + self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim) + + def forward(self, z): + """ + Parameters + ---------- + z: torch.Tensor[B x D x T] + + Returns + ------- + z_q: torch.Tensor[B x D x T] + Quantized continuous representation of input + commit_loss: Tensor[B] + Commitment loss to train encoder to predict vectors closer to codebook entries + codebook_loss: Tensor[B] + Codebook loss to update the codebook + indices: torch.Tensor[B x T] + Codebook indices (quantized discrete representation of input) + z_e: torch.Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim + z_e = self.in_project(z) + z_q, indices = self.decode_latents(z_e) + + # Compute commitment loss and codebook loss + if self.training: + commit_loss = ( + F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + * self.commitment + ) + codebook_loss = ( + F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + * self.codebook_loss_weight + ) + else: + commit_loss = torch.zeros(z.shape[0], device=z.device) + codebook_loss = torch.zeros(z.shape[0], device=z.device) + + z_q = z_e + (z_q - z_e).detach() + + z_q = self.out_project(z_q) + + return z_q, commit_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight + + # L2 normalize encodings and codebook + if self.use_l2_normlize: + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance between encodings and codebook, + # if use_l2_normlize is True, the distance is equal to cosine distance + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + return z_q, indices + + def vq2emb(self, vq, out_proj=True): + emb = self.decode_code(vq) + if out_proj: + emb = self.out_project(emb) + return emb + + def latent2dist(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight + + # L2 normalize encodings and codebook + if self.use_l2_normlize: + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance between encodings and codebook, + # if use_l2_normlize is True, the distance is equal to cosine distance + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) # (b*t, k) + + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + dist = rearrange(dist, "(b t) k -> b t k", b=latents.size(0)) + z_q = self.decode_code(indices) + + return -dist, indices, z_q diff --git a/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1801573bc8f5935cce465a48bd6e5192953e32 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class LookupFreeQuantize(nn.Module): + def __init__( + self, + input_dim, + codebook_size, + codebook_dim, + ): + super().__init__() + self.input_dim = input_dim + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + assert 2**codebook_dim == codebook_size + + if self.input_dim != self.codebook_dim: + self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) + self.out_project = WNConv1d( + self.codebook_dim, self.input_dim, kernel_size=1 + ) + + else: + self.in_project = nn.Identity() + self.out_project = nn.Identity() + + def forward(self, z): + z_e = self.in_project(z) + z_e = F.sigmoid(z_e) + + z_q = z_e + (torch.round(z_e) - z_e).detach() + + z_q = self.out_project(z_q) + + commit_loss = torch.zeros(z.shape[0], device=z.device) + codebook_loss = torch.zeros(z.shape[0], device=z.device) + + bits = ( + 2 + ** torch.arange(self.codebook_dim, device=z.device) + .unsqueeze(0) + .unsqueeze(-1) + .long() + ) # (1, d, 1) + indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long() + + return z_q, commit_loss, codebook_loss, indices, z_e + + def vq2emb(self, vq, out_proj=True): + emb = torch.zeros( + vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device + ) # (B, d, T) + for i in range(self.codebook_dim): + emb[:, i, :] = (vq % 2).float() + vq = vq // 2 + if out_proj: + emb = self.out_project(emb) + return emb diff --git a/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b5538f2774e4e9326025482ed6b524b611a72b --- /dev/null +++ b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py @@ -0,0 +1,177 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from indextts.utils.maskgct.models.codec.amphion_codec.quantize.factorized_vector_quantize import ( + FactorizedVectorQuantize, +) +from indextts.utils.maskgct.models.codec.amphion_codec.quantize.vector_quantize import VectorQuantize +from indextts.utils.maskgct.models.codec.amphion_codec.quantize.lookup_free_quantize import LookupFreeQuantize + + +class ResidualVQ(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 256, + num_quantizers: int = 8, + codebook_size: int = 1024, + codebook_dim: int = 256, + quantizer_type: str = "vq", # "vq" or "fvq" or "lfq" + quantizer_dropout: float = 0.5, + **kwargs, + ): + super().__init__() + + self.input_dim = input_dim + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer_type = quantizer_type + self.quantizer_dropout = quantizer_dropout + + if quantizer_type == "vq": + VQ = VectorQuantize + elif quantizer_type == "fvq": + VQ = FactorizedVectorQuantize + elif quantizer_type == "lfq": + VQ = LookupFreeQuantize + else: + raise ValueError(f"Unknown quantizer type {quantizer_type}") + + self.quantizers = nn.ModuleList( + [ + VQ( + input_dim=input_dim, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + **kwargs, + ) + for _ in range(num_quantizers) + ] + ) + + def forward(self, z, n_quantizers: int = None): + """ + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + "quantized_out" : Tensor[B x D x T] + Quantized continuous representation of input + "all_indices" : Tensor[N x B x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "all_commit_losses" : Tensor[N] + "all_codebook_losses" : Tensor[N] + "all_quantized" : Tensor[N x B x D x T] + """ + + quantized_out = 0.0 + residual = z + + all_commit_losses = [] + all_codebook_losses = [] + all_indices = [] + all_quantized = [] + + if n_quantizers is None: + n_quantizers = self.num_quantizers + + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.num_quantizers + 1 + dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commit_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + quantized_out = quantized_out + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + commit_loss_i = (commit_loss_i * mask).mean() + codebook_loss_i = (codebook_loss_i * mask).mean() + + all_commit_losses.append(commit_loss_i) + all_codebook_losses.append(codebook_loss_i) + all_indices.append(indices_i) + all_quantized.append(z_q_i) + + all_commit_losses, all_codebook_losses, all_indices, all_quantized = map( + torch.stack, + (all_commit_losses, all_codebook_losses, all_indices, all_quantized), + ) + + return ( + quantized_out, + all_indices, + all_commit_losses, + all_codebook_losses, + all_quantized, + ) + + def vq2emb(self, vq, n_quantizers=None): + quantized_out = 0.0 + if n_quantizers is None: + n_quantizers = self.num_quantizers + for idx, quantizer in enumerate(self.quantizers): + if idx >= n_quantizers: + break + quantized_out += quantizer.vq2emb(vq[idx]) + return quantized_out + + def latent2dist(self, z, n_quantizers=None): + quantized_out = 0.0 + residual = z + + all_dists = [] + all_indices = [] + + if n_quantizers is None: + n_quantizers = self.num_quantizers + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + dist_i, indices_i, z_q_i = quantizer.latent2dist(residual) + all_dists.append(dist_i) + all_indices.append(indices_i) + + quantized_out = quantized_out + z_q_i + residual = residual - z_q_i + + all_dists = torch.stack(all_dists) + all_indices = torch.stack(all_indices) + + return all_dists, all_indices diff --git a/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..8296893ac7422982cb4a794418f73d7d57c18c98 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py @@ -0,0 +1,401 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories, eps=1e-5): + return (x + eps) / (x.sum() + n_categories * eps) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + kmeans_init=False, + kmeans_iters=10, + decay=0.8, + eps=1e-5, + threshold_ema_dead_code=2, + weight_init=False, + ): + super().__init__() + + self.decay = decay + init_fn = torch.randn if not weight_init else torch.zeros + embed = init_fn(codebook_size, dim) + + if weight_init: + nn.init.uniform_(embed, -1 / codebook_size, 1 / codebook_size) + + self.codebook_size = codebook_size + self.kmeans_iters = kmeans_iters + self.eps = eps + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer( + "initted", torch.Tensor([not kmeans_init]) + ) # if kmeans_init is True, then initted is False; otherwise, initted is True + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + def init_embed_(self, data): + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def replace(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace(batch_samples, mask=expired_codes) + + def forward(self, x): + shape, dtype = x.shape, x.dtype + flatten = rearrange(x, "... d -> (...) d") + embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size) + + if not self.initted: + self.init_embed_(flatten) + + dist = -( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + + embed_ind = dist.max(dim=-1).indices + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = embed_ind.view(*shape[:-1]) + quantize = F.embedding(embed_ind, self.embed) + + if self.training: + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = ( + flatten.t() @ embed_onehot + ) # (dim, ...) @ (..., codebook_size) -> (dim, codebook_size) + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + self.expire_codes_(x) + + return quantize, embed_ind + + def vq2emb(self, vq): + quantize = F.embedding(vq, self.embed) + return quantize + + def latent2dist(self, x): + shape, dtype = x.shape, x.dtype + flatten = rearrange(x, "... d -> (...) d") + embed = self.embed.t() # (codebook_size, dim) -> (dim, codebook_size) + + if not self.initted: + self.init_embed_(flatten) + + dist = -( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + + embed_ind = dist.max(dim=-1).indices + embed_ind = embed_ind.view(*shape[:-1]) + quantize = F.embedding(embed_ind, self.embed) + + dist = dist.view(*shape[:-1], -1) + + return dist, embed_ind, quantize + + +class SimpleCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + use_l2_normlize=False, + ): + super().__init__() + + self.dim = dim + self.codebook_size = codebook_size + self.use_l2_normlize = use_l2_normlize + + self.embed = nn.Embedding(self.codebook_size, self.dim) + + def forward(self, x): + shape, dtype = x.shape, x.dtype + flatten = rearrange(x, "... d -> (...) d") + embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size) + + if self.use_l2_normlize: + flatten = F.normalize(flatten) + embed = F.normalize(embed) + + dist = -( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + + embed_ind = dist.max(dim=-1).indices + embed_ind = embed_ind.view(*shape[:-1]) + quantize = F.embedding(embed_ind, self.embed) + + return quantize, embed_ind + + def vq2emb(self, vq): + quantize = F.embedding(vq, self.embed.weight) + return quantize + + def latent2dist(self, x): + shape, dtype = x.shape, x.dtype + flatten = rearrange(x, "... d -> (...) d") + embed = self.embed.weight.t() # (codebook_size, dim) -> (dim, codebook_size) + + if self.use_l2_normlize: + flatten = F.normalize(flatten) + embed = F.normalize(embed) + + dist = -( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + + embed_ind = dist.max(dim=-1).indices + embed_ind = embed_ind.view(*shape[:-1]) + quantize = F.embedding(embed_ind, self.embed) + + dist = dist.view(*shape[:-1], -1) + + return dist, embed_ind, quantize + + +class VectorQuantize(nn.Module): + """Vector quantization and factorized vecotor quantization implementation + Args: + input_dim (int): Dimension of input. + codebook_size (int): Codebook size. + codebook_dim (int): Codebook dimension. We suggest use codebook_dim = input_dim + if use codebook_type == "euclidean", otherwise, if you want to use + factorized vector quantization, use codebook_dim as small number (e.g. 8 or 32). + commitment (float): Weight for commitment loss. + use_l2_normlize (bool): Whether to use l2 normlized codes for factorized vecotor quantization, + we suggest use it as True if you want to use factorized vector quantization + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + input_dim, + codebook_size, + codebook_dim, + commitment=0.005, + codebook_loss_weight=1.0, + use_l2_normlize=False, + codebook_type="euclidean", # "euclidean" or "simple" + kmeans_init=False, + kmeans_iters=10, + decay=0.8, + eps=1e-5, + threshold_ema_dead_code=2, + weight_init=False, + ): + super().__init__() + self.input_dim = input_dim + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + self.codebook_loss_weight = codebook_loss_weight + self.use_l2_normlize = use_l2_normlize + self.codebook_type = codebook_type + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.decay = decay + self.eps = eps + self.threshold_ema_dead_code = threshold_ema_dead_code + self.weight_init = weight_init + + if self.input_dim != self.codebook_dim: + self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) + self.out_project = WNConv1d( + self.codebook_dim, self.input_dim, kernel_size=1 + ) + + else: + self.in_project = nn.Identity() + self.out_project = nn.Identity() + + if self.codebook_type == "euclidean": + self.codebook = EuclideanCodebook( + self.codebook_dim, + codebook_size=self.codebook_size, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + decay=self.decay, + eps=self.eps, + threshold_ema_dead_code=self.threshold_ema_dead_code, + weight_init=self.weight_init, + ) + elif self.codebook_type == "simple": + self.codebook = SimpleCodebook( + self.codebook_dim, + codebook_size=self.codebook_size, + use_l2_normlize=self.use_l2_normlize, + ) + else: + raise NotImplementedError( + f"codebook_type {self.codebook_type} is not implemented!" + ) + + def forward(self, z): + """ + Parameters + ---------- + z: torch.Tensor[B x D x T] + + Returns + ------- + z_q: torch.Tensor[B x D x T] + Quantized continuous representation of input + commit_loss: Tensor[B] + Commitment loss to train encoder to predict vectors closer to codebook entries + codebook_loss: Tensor[B] + Codebook loss to update the codebook + indices: torch.Tensor[B x T] + Codebook indices (quantized discrete representation of input) + z_e: torch.Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim + z_e = self.in_project(z) + z_q, indices = self.decode_latents(z_e) + + # Compute commitment loss and codebook loss + if self.training: + commit_loss = ( + F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + * self.commitment + ) + codebook_loss = ( + F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + * self.codebook_loss_weight + ) + else: + commit_loss = torch.zeros(z.shape[0], device=z.device) + codebook_loss = torch.zeros(z.shape[0], device=z.device) + + z_q = z_e + (z_q - z_e).detach() + + z_q = self.out_project(z_q) + + return z_q, commit_loss, codebook_loss, indices, z_e + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> b t d") + z_q, indices = self.codebook(encodings) + z_q = z_q.transpose(1, 2) + return z_q, indices + + def vq2emb(self, vq, out_proj=True): + emb = self.codebook.vq2emb(vq) + emb = emb.transpose(1, 2) + if out_proj: + emb = self.out_project(emb) + return emb + + def latent2dist(self, latents): + latents = rearrange(latents, "b d t -> b t d") + dist, embed_ind, quantize = self.codebook.latent2dist(latents) + return dist, embed_ind, quantize.transpose(1, 2) diff --git a/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py b/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py new file mode 100644 index 0000000000000000000000000000000000000000..038d8ef4fd932d22c7704fca07f89ab675637ced --- /dev/null +++ b/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py @@ -0,0 +1,881 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import numpy as np +import scipy +import torch +from torch import nn, view_as_real, view_as_complex +from torch import nn +from torch.nn.utils import weight_norm, remove_weight_norm +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz +import librosa + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + + +class STFT(nn.Module): + def __init__( + self, + n_fft: int, + hop_length: int, + win_length: int, + center=True, + ): + super().__init__() + self.center = center + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, T * hop_length) + + if not self.center: + pad = self.win_length - self.hop_length + x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect") + + stft_spec = torch.stft( + x, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + return_complex=False, + ) # (B, n_fft // 2 + 1, T, 2) + + rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2) + imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2) + + log_mag = torch.log( + torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5 + ) # (B, n_fft // 2 + 1, T) + phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T) + + return log_mag, phase + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__( + self, n_fft: int, hop_length: int, win_length: int, padding: str = "same" + ): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft( + spec, + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + ) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + + +class MDCT(nn.Module): + """ + Modified Discrete Cosine Transform (MDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) + post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) + # view_as_real: NCCL Backend does not support ComplexFloat data type + # https://github.com/pytorch/pytorch/issues/71613 + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + """ + Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. + + Args: + audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size + and T is the length of the audio. + + Returns: + Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames + and N is the number of frequency bins. + """ + if self.padding == "center": + audio = torch.nn.functional.pad( + audio, (self.frame_len // 2, self.frame_len // 2) + ) + elif self.padding == "same": + # hop_length is 1/2 frame_len + audio = torch.nn.functional.pad( + audio, (self.frame_len // 4, self.frame_len // 4) + ) + else: + raise ValueError("Padding must be 'center' or 'same'.") + + x = audio.unfold(-1, self.frame_len, self.frame_len // 2) + N = self.frame_len // 2 + x = x * self.window.expand(x.shape) + X = torch.fft.fft( + x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1 + )[..., :N] + res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) + return torch.real(res) * np.sqrt(2) + + +class IMDCT(nn.Module): + """ + Inverse Modified Discrete Cosine Transform (IMDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. + """ + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft( + Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1 + ) + y = ( + torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) + * np.sqrt(N) + * np.sqrt(2) + ) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio + + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT( + n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip( + mag, max=1e2 + ) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + S = mag * (x + 1j * y) + audio = self.istft(S) + return audio + + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + sample_rate: Optional[int] = None, + clip_audio: bool = False, + ): + super().__init__() + out_dim = mdct_frame_len // 2 + self.out = nn.Linear(dim, out_dim) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + self.clip_audio = clip_audio + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.out.weight.mul_(scale.view(-1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + x = symexp(x) + x = torch.clip( + x, min=-1e2, max=1e2 + ) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + + return audio + + +class IMDCTCosHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + clip_audio: bool = False, + ): + super().__init__() + self.clip_audio = clip_audio + self.out = nn.Linear(dim, mdct_frame_len) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTCosHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + m, p = x.chunk(2, dim=2) + m = torch.exp(m).clip( + max=1e2 + ) # safeguard to prevent excessively large magnitudes + audio = self.imdct(m * torch.cos(p)) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + return audio + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward( + self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) + self.shift = nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + ] + ) + + self.gamma = nn.ParameterList( + [ + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. + """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + bandwidth_id = kwargs.get("bandwidth_id", None) + x = self.embed(x) + if self.adanorm: + assert bandwidth_id is not None + x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, cond_embedding_id=bandwidth_id) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. + """ + + def __init__( + self, + input_channels, + dim, + num_blocks, + layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm( + nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) + ) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ + ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) + for _ in range(num_blocks) + ] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x + + +class Vocos(nn.Module): + def __init__( + self, + input_channels: int = 256, + dim: int = 384, + intermediate_dim: int = 1152, + num_layers: int = 8, + n_fft: int = 800, + hop_size: int = 200, + padding: str = "same", + adanorm_num_embeddings=None, + cfg=None, + ): + super().__init__() + + input_channels = ( + cfg.input_channels + if cfg is not None and hasattr(cfg, "input_channels") + else input_channels + ) + dim = cfg.dim if cfg is not None and hasattr(cfg, "dim") else dim + intermediate_dim = ( + cfg.intermediate_dim + if cfg is not None and hasattr(cfg, "intermediate_dim") + else intermediate_dim + ) + num_layers = ( + cfg.num_layers + if cfg is not None and hasattr(cfg, "num_layers") + else num_layers + ) + adanorm_num_embeddings = ( + cfg.adanorm_num_embeddings + if cfg is not None and hasattr(cfg, "adanorm_num_embeddings") + else adanorm_num_embeddings + ) + n_fft = cfg.n_fft if cfg is not None and hasattr(cfg, "n_fft") else n_fft + hop_size = ( + cfg.hop_size if cfg is not None and hasattr(cfg, "hop_size") else hop_size + ) + padding = ( + cfg.padding if cfg is not None and hasattr(cfg, "padding") else padding + ) + + self.backbone = VocosBackbone( + input_channels=input_channels, + dim=dim, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + self.head = ISTFTHead(dim, n_fft, hop_size, padding) + + def forward(self, x): + x = self.backbone(x) + x = self.head(x) + + return x[:, None, :] diff --git a/indextts/utils/maskgct/models/codec/codec_dataset.py b/indextts/utils/maskgct/models/codec/codec_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..be0a30856a9ce8183a15013f71965b2f010647b4 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/codec_dataset.py @@ -0,0 +1,264 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Iterable +import torch +import numpy as np +import torch.utils.data +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from torch.utils.data import ConcatDataset, Dataset + + +class CodecDataset(torch.utils.data.Dataset): + def __init__(self, cfg, dataset, is_valid=False): + """ + Args: + cfg: config + dataset: dataset name + is_valid: whether to use train or valid dataset + """ + assert isinstance(dataset, str) + + processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset) + + meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file + self.metafile_path = os.path.join(processed_data_dir, meta_file) + self.metadata = self.get_metadata() + + self.data_root = processed_data_dir + self.cfg = cfg + + if cfg.preprocess.use_audio: + self.utt2audio_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2audio_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.audio_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_label: + self.utt2label_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2label_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.label_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_one_hot: + self.utt2one_hot_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2one_hot_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.one_hot_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_mel: + self.utt2mel_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2mel_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.mel_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_frame_pitch: + self.utt2frame_pitch_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2frame_pitch_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.pitch_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_uv: + self.utt2uv_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + self.utt2uv_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.uv_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_amplitude_phase: + self.utt2logamp_path = {} + self.utt2pha_path = {} + self.utt2rea_path = {} + self.utt2imag_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + self.utt2logamp_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.log_amplitude_dir, + uid + ".npy", + ) + self.utt2pha_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.phase_dir, + uid + ".npy", + ) + self.utt2rea_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.real_dir, + uid + ".npy", + ) + self.utt2imag_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.imaginary_dir, + uid + ".npy", + ) + + def __getitem__(self, index): + utt_info = self.metadata[index] + + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + single_feature = dict() + + if self.cfg.preprocess.use_mel: + mel = np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + + single_feature["mel"] = mel + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch = np.load(self.utt2frame_pitch_path[utt]) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_audio: + audio = np.load(self.utt2audio_path[utt]) + + single_feature["audio"] = audio + + return single_feature + + def get_metadata(self): + with open(self.metafile_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + + return metadata + + def get_dataset_name(self): + return self.metadata[0]["Dataset"] + + def __len__(self): + return len(self.metadata) + + +class CodecConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False): + """Concatenate a series of datasets with their random inference audio merged.""" + super().__init__(datasets) + + self.cfg = self.datasets[0].cfg + + self.metadata = [] + + # Merge metadata + for dataset in self.datasets: + self.metadata += dataset.metadata + + # Merge random inference features + if full_audio_inference: + self.eval_audios = [] + self.eval_dataset_names = [] + if self.cfg.preprocess.use_mel: + self.eval_mels = [] + if self.cfg.preprocess.use_frame_pitch: + self.eval_pitchs = [] + for dataset in self.datasets: + self.eval_audios.append(dataset.eval_audio) + self.eval_dataset_names.append(dataset.get_dataset_name()) + if self.cfg.preprocess.use_mel: + self.eval_mels.append(dataset.eval_mel) + if self.cfg.preprocess.use_frame_pitch: + self.eval_pitchs.append(dataset.eval_pitch) + + +class CodecCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, n_mels, frame] + # frame_pitch: [b, frame] + # audios: [b, frame * hop_size] + + for key in batch[0].keys(): + if key == "target_len": + packed_batch_features["target_len"] = torch.LongTensor( + [b["target_len"] for b in batch] + ) + masks = [ + torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "mel": + values = [torch.from_numpy(b[key]).T for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + return packed_batch_features diff --git a/indextts/utils/maskgct/models/codec/codec_inference.py b/indextts/utils/maskgct/models/codec/codec_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..95e354c5db80cbac986543fdf7923014426c5078 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/codec_inference.py @@ -0,0 +1,515 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import json +import json5 +import time +import accelerate +import random +import numpy as np +import shutil + +from pathlib import Path +from tqdm import tqdm +from glob import glob +from accelerate.logging import get_logger +from torch.utils.data import DataLoader + +from models.vocoders.vocoder_dataset import ( + VocoderDataset, + VocoderCollator, + VocoderConcatDataset, +) + +from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet +from models.vocoders.flow.waveglow import waveglow +from models.vocoders.diffusion.diffwave import diffwave +from models.vocoders.autoregressive.wavenet import wavenet +from models.vocoders.autoregressive.wavernn import wavernn + +from models.vocoders.gan import gan_vocoder_inference +from models.vocoders.diffusion import diffusion_vocoder_inference + +from utils.io import save_audio + +_vocoders = { + "diffwave": diffwave.DiffWave, + "wavernn": wavernn.WaveRNN, + "wavenet": wavenet.WaveNet, + "waveglow": waveglow.WaveGlow, + "nsfhifigan": nsfhifigan.NSFHiFiGAN, + "bigvgan": bigvgan.BigVGAN, + "hifigan": hifigan.HiFiGAN, + "melgan": melgan.MelGAN, + "apnet": apnet.APNet, +} + +# Forward call for generalized Inferencor +_vocoder_forward_funcs = { + # "world": world_inference.synthesis_audios, + # "wavernn": wavernn_inference.synthesis_audios, + # "wavenet": wavenet_inference.synthesis_audios, + "diffwave": diffusion_vocoder_inference.vocoder_inference, + "nsfhifigan": gan_vocoder_inference.vocoder_inference, + "bigvgan": gan_vocoder_inference.vocoder_inference, + "melgan": gan_vocoder_inference.vocoder_inference, + "hifigan": gan_vocoder_inference.vocoder_inference, + "apnet": gan_vocoder_inference.vocoder_inference, +} + +# APIs for other tasks. e.g. SVC, TTS, TTA... +_vocoder_infer_funcs = { + # "world": world_inference.synthesis_audios, + # "wavernn": wavernn_inference.synthesis_audios, + # "wavenet": wavenet_inference.synthesis_audios, + "diffwave": diffusion_vocoder_inference.synthesis_audios, + "nsfhifigan": gan_vocoder_inference.synthesis_audios, + "bigvgan": gan_vocoder_inference.synthesis_audios, + "melgan": gan_vocoder_inference.synthesis_audios, + "hifigan": gan_vocoder_inference.synthesis_audios, + "apnet": gan_vocoder_inference.synthesis_audios, +} + + +class VocoderInference(object): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + super().__init__() + + start = time.monotonic_ns() + self.args = args + self.cfg = cfg + self.infer_type = infer_type + + # Init accelerator + self.accelerator = accelerate.Accelerator() + self.accelerator.wait_for_everyone() + + # Get logger + with self.accelerator.main_process_first(): + self.logger = get_logger("inference", log_level=args.log_level) + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New inference process started." + "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + + self.vocoder_dir = args.vocoder_dir + self.logger.debug(f"Vocoder dir: {args.vocoder_dir}") + + os.makedirs(args.output_dir, exist_ok=True) + if os.path.exists(os.path.join(args.output_dir, "pred")): + shutil.rmtree(os.path.join(args.output_dir, "pred")) + if os.path.exists(os.path.join(args.output_dir, "gt")): + shutil.rmtree(os.path.join(args.output_dir, "gt")) + os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True) + + # Set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # Setup inference mode + if self.infer_type == "infer_from_dataset": + self.cfg.dataset = self.args.infer_datasets + elif self.infer_type == "infer_from_feature": + self._build_tmp_dataset_from_feature() + self.cfg.dataset = ["tmp"] + elif self.infer_type == "infer_from_audio": + self._build_tmp_dataset_from_audio() + self.cfg.dataset = ["tmp"] + + # Setup data loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.test_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # Build model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms") + + # Init with accelerate + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + self.accelerator = accelerate.Accelerator() + (self.model, self.test_dataloader) = self.accelerator.prepare( + self.model, self.test_dataloader + ) + end = time.monotonic_ns() + self.accelerator.wait_for_everyone() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms") + + with self.accelerator.main_process_first(): + self.logger.info("Loading checkpoint...") + start = time.monotonic_ns() + if os.path.isdir(args.vocoder_dir): + if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")): + self._load_model(os.path.join(args.vocoder_dir, "checkpoint")) + else: + self._load_model(os.path.join(args.vocoder_dir)) + else: + self._load_model(os.path.join(args.vocoder_dir)) + end = time.monotonic_ns() + self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms") + + self.model.eval() + self.accelerator.wait_for_everyone() + + def _build_tmp_dataset_from_feature(self): + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + utts = [] + mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy")) + for i, mel in enumerate(mels): + uid = mel.split("/")[-1].split(".")[0] + utt = {"Dataset": "tmp", "Uid": uid, "index": i} + utts.append(utt) + + os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w" + ) as f: + json.dump(utts, f) + + meta_info = {"dataset": "tmp", "test": {"size": len(utts)}} + + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"), + "w", + ) as f: + json.dump(meta_info, f) + + features = glob(os.path.join(self.args.feature_folder, "*")) + for feature in features: + feature_name = feature.split("/")[-1] + if os.path.isfile(feature): + continue + shutil.copytree( + os.path.join(self.args.feature_folder, feature_name), + os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name), + ) + + def _build_tmp_dataset_from_audio(self): + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + utts = [] + audios = glob(os.path.join(self.args.audio_folder, "*")) + for i, audio in enumerate(audios): + uid = audio.split("/")[-1].split(".")[0] + utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio} + utts.append(utt) + + os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w" + ) as f: + json.dump(utts, f) + + meta_info = {"dataset": "tmp", "test": {"size": len(utts)}} + + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"), + "w", + ) as f: + json.dump(meta_info, f) + + from processors import acoustic_extractor + + acoustic_extractor.extract_utt_acoustic_features_serial( + utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg + ) + + def _build_test_dataset(self): + return VocoderDataset, VocoderCollator + + def _build_model(self): + model = _vocoders[self.cfg.model.generator](self.cfg) + return model + + def _build_dataloader(self): + """Build dataloader which merges a series of datasets.""" + Dataset, Collator = self._build_test_dataset() + + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False) + test_collate = Collator(self.cfg) + test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset)) + test_dataloader = DataLoader( + test_dataset, + collate_fn=test_collate, + num_workers=1, + batch_size=test_batch_size, + shuffle=False, + ) + self.test_batch_size = test_batch_size + self.test_dataset = test_dataset + return test_dataloader + + def _load_model(self, checkpoint_dir, from_multi_gpu=False): + """Load model from checkpoint. If a folder is given, it will + load the latest checkpoint in checkpoint_dir. If a path is given + it will load the checkpoint specified by checkpoint_path. + **Only use this method after** ``accelerator.prepare()``. + """ + if os.path.isdir(checkpoint_dir): + if "epoch" in checkpoint_dir and "step" in checkpoint_dir: + checkpoint_path = checkpoint_dir + else: + # Load the latest accelerator state dicts + ls = [ + str(i) + for i in Path(checkpoint_dir).glob("*") + if not "audio" in str(i) + ] + ls.sort( + key=lambda x: int(x.split("/")[-1].split("_")[0].split("-")[-1]), + reverse=True, + ) + checkpoint_path = ls[0] + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + return str(checkpoint_path) + else: + # Load old .pt checkpoints + if self.cfg.model.generator in [ + "bigvgan", + "hifigan", + "melgan", + "nsfhifigan", + ]: + ckpt = torch.load( + checkpoint_dir, + map_location=( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ), + ) + if from_multi_gpu: + pretrained_generator_dict = ckpt["generator_state_dict"] + generator_dict = self.model.state_dict() + + new_generator_dict = { + k.split("module.")[-1]: v + for k, v in pretrained_generator_dict.items() + if ( + k.split("module.")[-1] in generator_dict + and v.shape == generator_dict[k.split("module.")[-1]].shape + ) + } + + generator_dict.update(new_generator_dict) + + self.model.load_state_dict(generator_dict) + else: + self.model.load_state_dict(ckpt["generator_state_dict"]) + else: + self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"]) + return str(checkpoint_dir) + + def inference(self): + """Inference via batches""" + for i, batch in tqdm(enumerate(self.test_dataloader)): + if self.cfg.preprocess.use_frame_pitch: + audio_pred = _vocoder_forward_funcs[self.cfg.model.generator]( + self.cfg, + self.model, + batch["mel"].transpose(-1, -2), + f0s=batch["frame_pitch"].float(), + device=next(self.model.parameters()).device, + ) + else: + audio_pred = _vocoder_forward_funcs[self.cfg.model.generator]( + self.cfg, + self.model, + batch["mel"].transpose(-1, -2), + device=next(self.model.parameters()).device, + ) + audio_ls = audio_pred.chunk(self.test_batch_size) + audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size) + length_ls = batch["target_len"].cpu().chunk(self.test_batch_size) + j = 0 + for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls): + l = l.item() + it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size] + it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size] + uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"] + save_audio( + os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid), + it, + self.cfg.preprocess.sample_rate, + ) + save_audio( + os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid), + it_gt, + self.cfg.preprocess.sample_rate, + ) + j += 1 + + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + def _set_random_seed(self, seed): + """Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _count_parameters(self, model): + return sum(p.numel() for p in model.parameters()) + + def _dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + +def load_nnvocoder( + cfg, + vocoder_name, + weights_file, + from_multi_gpu=False, +): + """Load the specified vocoder. + cfg: the vocoder config filer. + weights_file: a folder or a .pt path. + from_multi_gpu: automatically remove the "module" string in state dicts if "True". + """ + print("Loading Vocoder from Weights file: {}".format(weights_file)) + + # Build model + model = _vocoders[vocoder_name](cfg) + if not os.path.isdir(weights_file): + # Load from .pt file + if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]: + ckpt = torch.load( + weights_file, + map_location=( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ), + ) + if from_multi_gpu: + pretrained_generator_dict = ckpt["generator_state_dict"] + generator_dict = model.state_dict() + + new_generator_dict = { + k.split("module.")[-1]: v + for k, v in pretrained_generator_dict.items() + if ( + k.split("module.")[-1] in generator_dict + and v.shape == generator_dict[k.split("module.")[-1]].shape + ) + } + + generator_dict.update(new_generator_dict) + + model.load_state_dict(generator_dict) + else: + model.load_state_dict(ckpt["generator_state_dict"]) + else: + model.load_state_dict(torch.load(weights_file)["state_dict"]) + else: + # Load from accelerator state dict + weights_file = os.path.join(weights_file, "checkpoint") + ls = [str(i) for i in Path(weights_file).glob("*") if not "audio" in str(i)] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + accelerator = accelerate.Accelerator() + model = accelerator.prepare(model) + accelerator.load_state(checkpoint_path) + + if torch.cuda.is_available(): + model = model.cuda() + + model = model.eval() + return model + + +def tensorize(data, device, n_samples): + """ + data: a list of numpy array + """ + assert type(data) == list + if n_samples: + data = data[:n_samples] + data = [torch.as_tensor(x, device=device) for x in data] + return data + + +def synthesis( + cfg, + vocoder_weight_file, + n_samples, + pred, + f0s=None, + batch_size=64, + fast_inference=False, +): + """Synthesis audios from a given vocoder and series of given features. + cfg: vocoder config. + vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file. + pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...] + """ + + vocoder_name = cfg.model.generator + + print("Synthesis audios using {} vocoder...".format(vocoder_name)) + + ###### TODO: World Vocoder Refactor ###### + # if vocoder_name == "world": + # world_inference.synthesis_audios( + # cfg, dataset_name, split, n_samples, pred, save_dir, tag + # ) + # return + + # ====== Loading neural vocoder model ====== + vocoder = load_nnvocoder( + cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True + ) + device = next(vocoder.parameters()).device + + # ====== Inference for predicted acoustic features ====== + # pred: (frame_len, n_mels) -> (n_mels, frame_len) + mels_pred = tensorize([p.T for p in pred], device, n_samples) + print("For predicted mels, #sample = {}...".format(len(mels_pred))) + audios_pred = _vocoder_infer_funcs[vocoder_name]( + cfg, + vocoder, + mels_pred, + f0s=f0s, + batch_size=batch_size, + fast_inference=fast_inference, + ) + return audios_pred diff --git a/indextts/utils/maskgct/models/codec/codec_sampler.py b/indextts/utils/maskgct/models/codec/codec_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9d29f88a291dcf7386cadaeae0d990c8e76ebf98 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/codec_sampler.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random + +from torch.utils.data import ConcatDataset, Dataset +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, +) + + +class ScheduledSampler(Sampler): + """A sampler that samples data from a given concat-dataset. + + Args: + concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets + batch_size (int): batch size + holistic_shuffle (bool): whether to shuffle the whole dataset or not + logger (logging.Logger): logger to print warning message + + Usage: + For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True: + >>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]]))) + [3, 4, 5, 0, 1, 2, 6, 7, 8] + """ + + def __init__( + self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train" + ): + if not isinstance(concat_dataset, ConcatDataset): + raise ValueError( + "concat_dataset must be an instance of ConcatDataset, but got {}".format( + type(concat_dataset) + ) + ) + if not isinstance(batch_size, int): + raise ValueError( + "batch_size must be an integer, but got {}".format(type(batch_size)) + ) + if not isinstance(holistic_shuffle, bool): + raise ValueError( + "holistic_shuffle must be a boolean, but got {}".format( + type(holistic_shuffle) + ) + ) + + self.concat_dataset = concat_dataset + self.batch_size = batch_size + self.holistic_shuffle = holistic_shuffle + + affected_dataset_name = [] + affected_dataset_len = [] + for dataset in concat_dataset.datasets: + dataset_len = len(dataset) + dataset_name = dataset.get_dataset_name() + if dataset_len < batch_size: + affected_dataset_name.append(dataset_name) + affected_dataset_len.append(dataset_len) + + self.type = type + for dataset_name, dataset_len in zip( + affected_dataset_name, affected_dataset_len + ): + if not type == "valid": + logger.warning( + "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format( + type, dataset_name, dataset_len, batch_size + ) + ) + + def __len__(self): + # the number of batches with drop last + num_of_batches = sum( + [ + math.floor(len(dataset) / self.batch_size) + for dataset in self.concat_dataset.datasets + ] + ) + return num_of_batches * self.batch_size + + def __iter__(self): + iters = [] + for dataset in self.concat_dataset.datasets: + iters.append( + SequentialSampler(dataset).__iter__() + if self.holistic_shuffle + else RandomSampler(dataset).__iter__() + ) + init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1] + output_batches = [] + for dataset_idx in range(len(self.concat_dataset.datasets)): + cur_batch = [] + for idx in iters[dataset_idx]: + cur_batch.append(idx + init_indices[dataset_idx]) + if len(cur_batch) == self.batch_size: + output_batches.append(cur_batch) + cur_batch = [] + if self.type == "valid" and len(cur_batch) > 0: + output_batches.append(cur_batch) + cur_batch = [] + # force drop last in training + random.shuffle(output_batches) + output_indices = [item for sublist in output_batches for item in sublist] + return iter(output_indices) + + +def build_samplers(concat_dataset: Dataset, cfg, logger, type): + sampler = ScheduledSampler( + concat_dataset, + cfg.train.batch_size, + cfg.train.sampler.holistic_shuffle, + logger, + type, + ) + batch_sampler = BatchSampler( + sampler, + cfg.train.batch_size, + cfg.train.sampler.drop_last if not type == "valid" else False, + ) + return sampler, batch_sampler diff --git a/indextts/utils/maskgct/models/codec/codec_trainer.py b/indextts/utils/maskgct/models/codec/codec_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8a6f838814c194b9d3bccfd5c8e66ea5881a33c6 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/codec_trainer.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import random +from pathlib import Path +import re + +import accelerate +import json5 +import numpy as np +import torch +from accelerate.utils import ProjectConfiguration +from torch.utils.data import DataLoader +from tqdm import tqdm + +from models.codec.codec_sampler import build_samplers + + +class CodecTrainer: + def __init__(self): + super().__init__() + + def _init_accelerator(self): + """Initialize the accelerator components.""" + self.exp_dir = os.path.join( + os.path.abspath(self.cfg.log_dir), self.args.exp_name + ) + project_config = ProjectConfiguration( + project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log") + ) + self.accelerator = accelerate.Accelerator( + gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step, + log_with=self.cfg.train.tracker, + project_config=project_config, + ) + if self.accelerator.is_main_process: + os.makedirs(project_config.project_dir, exist_ok=True) + os.makedirs(project_config.logging_dir, exist_ok=True) + with self.accelerator.main_process_first(): + self.accelerator.init_trackers(self.args.exp_name) + + def _build_dataset(self): + pass + + def _build_criterion(self): + pass + + def _build_model(self): + pass + + def _build_dataloader(self): + """Build dataloader which merges a series of datasets.""" + # Build dataset instance for each dataset and combine them by ConcatDataset + Dataset, Collator = self._build_dataset() + + # Build train set + train_dataset = Dataset(self.cfg, self.cfg.dataset, is_valid=False) + train_collate = Collator(self.cfg) + sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=self.accelerator.num_processes, + rank=self.accelerator.local_process_index, + shuffle=True, + seed=self.cfg.train.random_seed, + ) + train_loader = DataLoader( + train_dataset, + batch_size=self.cfg.train.batch_size, + collate_fn=train_collate, + sampler=sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + return train_loader, None + + def _build_optimizer(self): + pass + + def _build_scheduler(self): + pass + + def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"): + """Load model from checkpoint. If a folder is given, it will + load the latest checkpoint in checkpoint_dir. If a path is given + it will load the checkpoint specified by checkpoint_path. + **Only use this method after** ``accelerator.prepare()``. + """ + if checkpoint_path is None: + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + if resume_type == "resume": + self.accelerator.load_state(checkpoint_path) + elif resume_type == "finetune": + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + self.logger.info("Load model weights for finetune SUCCESS!") + else: + raise ValueError("Unsupported resume type: {}".format(resume_type)) + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + return checkpoint_path + + def train_loop(self): + pass + + def _train_epoch(self): + pass + + def _valid_epoch(self): + pass + + def _train_step(self): + pass + + def _valid_step(self): + pass + + def _inference(self): + pass + + def _set_random_seed(self, seed): + """Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _check_nan(self, loss): + if torch.any(torch.isnan(loss)): + self.logger.fatal("Fatal Error: NaN!") + self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True) + + def _check_basic_configs(self): + if self.cfg.train.gradient_accumulation_step <= 0: + self.logger.fatal("Invalid gradient_accumulation_step value!") + self.logger.error( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + self.accelerator.end_training() + raise ValueError( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + + def _count_parameters(self): + pass + + def _dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + def _is_valid_pattern(self, directory_name): + directory_name = str(directory_name) + pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}" + return re.match(pattern, directory_name) is not None diff --git a/indextts/utils/maskgct/models/codec/facodec/__init__.py b/indextts/utils/maskgct/models/codec/facodec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py b/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bccdb97a3706bcb7149f48e04178cf00a5e877 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py @@ -0,0 +1,5 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +from .filter import * +from .resample import * +from .act import * diff --git a/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py b/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..779d58d5f1e889f8b639dd019a0ce951e69e4cfb --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py @@ -0,0 +1,29 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py b/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..ece8e02fce0e65e13522e990a80d1bfeeffd46ba --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py @@ -0,0 +1,96 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +def kaiser_sinc_filter1d( + cutoff, half_width, kernel_size +): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + # input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out diff --git a/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py b/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..ee993b10339141b469b67c3e11f5d73c5f4e0bca --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py @@ -0,0 +1,57 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py b/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e86b82d1f0e6e49395dd9340961bdd517b47f8b3 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py @@ -0,0 +1,98 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import random + +import numpy as np + +import torchaudio +import librosa +from torch.nn import functional as F + +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from models.codec.codec_dataset import CodecDataset + + +class FAcodecDataset(torch.utils.data.Dataset): + def __init__(self, cfg, dataset, is_valid=False): + """ + Args: + cfg: config + dataset: dataset name + is_valid: whether to use train or valid dataset + """ + self.data_root_dir = cfg.dataset + self.data_list = [] + # walk through the dataset directory recursively, save all files ends with .wav/.mp3/.opus/.flac/.m4a + for root, _, files in os.walk(self.data_root_dir): + for file in files: + if file.endswith((".wav", ".mp3", ".opus", ".flac", ".m4a")): + self.data_list.append(os.path.join(root, file)) + self.sr = cfg.preprocess_params.sr + self.duration_range = cfg.preprocess_params.duration_range + self.to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=cfg.preprocess_params.spect_params.n_mels, + n_fft=cfg.preprocess_params.spect_params.n_fft, + win_length=cfg.preprocess_params.spect_params.win_length, + hop_length=cfg.preprocess_params.spect_params.hop_length, + ) + self.mean, self.std = -4, 4 + + def preprocess(self, wave): + wave_tensor = ( + torch.from_numpy(wave).float() if isinstance(wave, np.ndarray) else wave + ) + mel_tensor = self.to_mel(wave_tensor) + mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std + return mel_tensor + + def __len__(self): + # return len(self.data_list) + return len(self.data_list) # return a fixed number for testing + + def __getitem__(self, index): + wave, _ = librosa.load(self.data_list[index], sr=self.sr) + wave = np.random.randn(self.sr * random.randint(*self.duration_range)) + wave = wave / np.max(np.abs(wave)) + mel = self.preprocess(wave).squeeze(0) + wave = torch.from_numpy(wave).float() + return wave, mel + + +class FAcodecCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + # batch[0] = wave, mel, text, f0, speakerid + batch_size = len(batch) + + # sort by mel length + lengths = [b[1].shape[1] for b in batch] + batch_indexes = np.argsort(lengths)[::-1] + batch = [batch[bid] for bid in batch_indexes] + + nmels = batch[0][1].size(0) + max_mel_length = max([b[1].shape[1] for b in batch]) + max_wave_length = max([b[0].size(0) for b in batch]) + + mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10 + waves = torch.zeros((batch_size, max_wave_length)).float() + + mel_lengths = torch.zeros(batch_size).long() + wave_lengths = torch.zeros(batch_size).long() + + for bid, (wave, mel) in enumerate(batch): + mel_size = mel.size(1) + mels[bid, :, :mel_size] = mel + waves[bid, : wave.size(0)] = wave + mel_lengths[bid] = mel_size + wave_lengths[bid] = wave.size(0) + + return waves, mels, wave_lengths, mel_lengths diff --git a/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py b/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c494349e4c1140d9d11e9c5742a8faa7e1560705 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import shutil +import warnings +import argparse +import torch +import os +import yaml + +warnings.simplefilter("ignore") + +from .modules.commons import * +import time + +import torchaudio +import librosa +from collections import OrderedDict + + +class FAcodecInference(object): + def __init__(self, args=None, cfg=None): + self.args = args + self.cfg = cfg + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = self._build_model() + self._load_checkpoint() + + def _build_model(self): + model = build_model(self.cfg.model_params) + _ = [model[key].to(self.device) for key in model] + return model + + def _load_checkpoint(self): + sd = torch.load(self.args.checkpoint_path, map_location="cpu") + sd = sd["net"] if "net" in sd else sd + new_params = dict() + for key, state_dict in sd.items(): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith("module."): + k = k[7:] + new_state_dict[k] = v + new_params[key] = new_state_dict + for key in new_params: + if key in self.model: + self.model[key].load_state_dict(new_params[key]) + _ = [self.model[key].eval() for key in self.model] + + @torch.no_grad() + def inference(self, source, output_dir): + source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0] + source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device) + + z = self.model.encoder(source_audio[None, ...].to(self.device).float()) + ( + z, + quantized, + commitment_loss, + codebook_loss, + timbre, + codes, + ) = self.model.quantizer( + z, + source_audio[None, ...].to(self.device).float(), + n_c=self.cfg.model_params.n_c_codebooks, + return_codes=True, + ) + + full_pred_wave = self.model.decoder(z) + + os.makedirs(output_dir, exist_ok=True) + source_name = source.split("/")[-1].split(".")[0] + torchaudio.save( + f"{output_dir}/reconstructed_{source_name}.wav", + full_pred_wave[0].cpu(), + self.cfg.preprocess_params.sr, + ) + + print( + "Reconstructed audio saved as: ", + f"{output_dir}/reconstructed_{source_name}.wav", + ) + + return quantized, codes + + @torch.no_grad() + def voice_conversion(self, source, reference, output_dir): + source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0] + source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device) + + reference_audio = librosa.load(reference, sr=self.cfg.preprocess_params.sr)[0] + reference_audio = ( + torch.tensor(reference_audio).unsqueeze(0).float().to(self.device) + ) + + z = self.model.encoder(source_audio[None, ...].to(self.device).float()) + z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer( + z, + source_audio[None, ...].to(self.device).float(), + n_c=self.cfg.model_params.n_c_codebooks, + ) + + z_ref = self.model.encoder(reference_audio[None, ...].to(self.device).float()) + ( + z_ref, + quantized_ref, + commitment_loss_ref, + codebook_loss_ref, + timbre_ref, + ) = self.model.quantizer( + z_ref, + reference_audio[None, ...].to(self.device).float(), + n_c=self.cfg.model_params.n_c_codebooks, + ) + + z_conv = self.model.quantizer.voice_conversion( + quantized[0] + quantized[1], + reference_audio[None, ...].to(self.device).float(), + ) + full_pred_wave = self.model.decoder(z_conv) + + os.makedirs(output_dir, exist_ok=True) + source_name = source.split("/")[-1].split(".")[0] + reference_name = reference.split("/")[-1].split(".")[0] + torchaudio.save( + f"{output_dir}/converted_{source_name}_to_{reference_name}.wav", + full_pred_wave[0].cpu(), + self.cfg.preprocess_params.sr, + ) + + print( + "Voice conversion results saved as: ", + f"{output_dir}/converted_{source_name}_to_{reference_name}.wav", + ) diff --git a/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py b/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0b685739def8c036f319ce76b7dc7b827dba8e --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py @@ -0,0 +1,776 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import time +import random +from pathlib import Path +import re +import glob + +import accelerate +import json +import numpy as np +import torch +from accelerate.utils import ProjectConfiguration +from torch.utils.data import DataLoader +from tqdm import tqdm + +import torch +import torch.nn.functional as F +import torchaudio + +from accelerate.logging import get_logger + +from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator +from models.codec.codec_sampler import build_samplers +from models.codec.codec_trainer import CodecTrainer + +from modules.dac.nn.loss import ( + MultiScaleSTFTLoss, + MelSpectrogramLoss, + GANLoss, + L1Loss, + FocalLoss, +) +from audiotools import AudioSignal + +from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC + +try: + import nemo.collections.asr as nemo_asr +except ImportError: + print( + "Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING" + ) + nemo_asr = None + +from models.codec.facodec.modules.commons import ( + build_model, + load_checkpoint, + load_F0_models, + log_norm, +) +from models.codec.facodec.optimizer import build_optimizer + + +class FAcodecTrainer(CodecTrainer): + def __init__(self, args, cfg): + super().__init__() + + self.args = args + self.cfg = cfg + + cfg.exp_name = args.exp_name + + # Init accelerator + self._init_accelerator() + self.accelerator.wait_for_everyone() + + # Init logger + with self.accelerator.main_process_first(): + self.logger = get_logger(args.exp_name, log_level=args.log_level) + + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New training process started." + "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + self.logger.info(f"Experiment name: {args.exp_name}") + self.logger.info(f"Experiment directory: {self.exp_dir}") + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # Init training status + self.batch_count: int = 0 + self.step: int = 0 + self.epoch: int = 0 + + self.max_epoch = ( + self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") + ) + self.logger.info( + "Max epoch: {}".format( + self.max_epoch if self.max_epoch < float("inf") else "Unlimited" + ) + ) + + # Check potential erorrs + if self.accelerator.is_main_process: + self._check_basic_configs() + self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride + self.checkpoints_path = [ + [] for _ in range(len(self.save_checkpoint_stride)) + ] + self.run_eval = self.cfg.train.run_eval + + # Set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # Build dataloader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.train_dataloader, self.valid_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # Build model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + for _, model in self.model.items(): + self.logger.debug(model) + self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms") + self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M") + + # Build optimizers and schedulers + with self.accelerator.main_process_first(): + self.logger.info("Building optimizer and scheduler...") + start = time.monotonic_ns() + self.optimizer = self._build_optimizer() + end = time.monotonic_ns() + self.logger.info( + f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" + ) + + # Build helper models + with self.accelerator.main_process_first(): + self.logger.info("Building helper models...") + start = time.monotonic_ns() + self._built_helper_model() + end = time.monotonic_ns() + self.logger.info( + f"Building helper models done in {(end - start) / 1e6:.2f}ms" + ) + + # Accelerator preparing + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + for k in self.model: + self.model[k] = self.accelerator.prepare(self.model[k]) + for k, v in self.optimizer.optimizers.items(): + self.optimizer.optimizers[k] = self.accelerator.prepare( + self.optimizer.optimizers[k] + ) + self.optimizer.schedulers[k] = self.accelerator.prepare( + self.optimizer.schedulers[k] + ) + end = time.monotonic_ns() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") + + # Build criterions + with self.accelerator.main_process_first(): + self.logger.info("Building criterion...") + start = time.monotonic_ns() + self.criterions = self._build_criterion() + end = time.monotonic_ns() + self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms") + + # Resume checkpoints + with self.accelerator.main_process_first(): + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if args.resume_type: + self.logger.info("Resuming from checkpoint...") + start = time.monotonic_ns() + ckpt_path = Path(args.checkpoint) + if self._is_valid_pattern(ckpt_path.parts[-1]): + ckpt_path = self._load_model(args.checkpoint, args.resume_type) + else: + ckpt_path = self._load_model( + args.checkpoint, resume_type=args.resume_type + ) + end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + self.checkpoints_path = json.load( + open(os.path.join(ckpt_path, "ckpts.json"), "r") + ) + + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # Save config + self.config_save_path = os.path.join(self.exp_dir, "args.json") + + def _build_dataset(self): + return FAcodecDataset, FAcodecCollator + + def _build_criterion(self): + criterions = dict() + stft_criterion = MultiScaleSTFTLoss() + mel_criterion = MelSpectrogramLoss( + n_mels=[5, 10, 20, 40, 80, 160, 320], + window_lengths=[32, 64, 128, 256, 512, 1024, 2048], + mel_fmin=[0, 0, 0, 0, 0, 0, 0], + mel_fmax=[None, None, None, None, None, None, None], + pow=1.0, + mag_weight=0.0, + clamp_eps=1e-5, + ) + content_criterion = FocalLoss(gamma=2) + l1_criterion = L1Loss() + criterions["stft"] = stft_criterion + criterions["mel"] = mel_criterion + criterions["l1"] = l1_criterion + criterions["content"] = content_criterion + + return criterions + + def _build_model(self): + model = build_model(self.cfg.model_params) + _ = [model[key].to(self.accelerator.device) for key in model] + return model + + def _built_helper_model(self): + device = self.accelerator.device + self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(device) + + # load model and processor + self.w2v_processor = Wav2Vec2Processor.from_pretrained( + "facebook/wav2vec2-xlsr-53-espeak-cv-ft" + ) + self.w2v_model = Wav2Vec2ForCTC.from_pretrained( + "facebook/wav2vec2-xlsr-53-espeak-cv-ft" + ).to(device) + self.w2v_model.eval() + + if nemo_asr is None: + self.speaker_model = None + else: + self.speaker_model = ( + nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + "nvidia/speakerverification_en_titanet_large" + ) + ) + self.speaker_model = self.speaker_model.to(device) + self.speaker_model.eval() + + def _build_optimizer(self): + scheduler_params = { + "warmup_steps": self.cfg.loss_params.warmup_steps, + "base_lr": self.cfg.loss_params.base_lr, + } + optimizer = build_optimizer( + {key: self.model[key] for key in self.model}, + scheduler_params_dict={key: scheduler_params.copy() for key in self.model}, + lr=float(scheduler_params["base_lr"]), + ) + + return optimizer + + def train_loop(self): + """Training process""" + self.accelerator.wait_for_everyone() + + # Dump config + if self.accelerator.is_main_process: + self._dump_cfg(self.config_save_path) + _ = [self.model[key].train() for key in self.model] + self.optimizer.zero_grad() + + # Sync and start training + self.accelerator.wait_for_everyone() + while self.epoch < self.max_epoch: + self.logger.info("\n") + self.logger.info("-" * 32) + self.logger.info("Epoch {}: ".format(self.epoch)) + + # Train and Validate + train_total_loss, train_losses = self._train_epoch() + for key, loss in train_losses.items(): + self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss)) + self.accelerator.log( + {"Epoch/Train {} Loss".format(key): loss}, + step=self.epoch, + ) + self.accelerator.log( + { + "Epoch/Train Total Loss": train_total_loss, + }, + step=self.epoch, + ) + + # Update scheduler + self.accelerator.wait_for_everyone() + + # Check save checkpoint interval + run_eval = False + if self.accelerator.is_main_process: + save_checkpoint = False + for i, num in enumerate(self.save_checkpoint_stride): + if self.epoch % num == 0: + save_checkpoint = True + run_eval |= self.run_eval[i] + + # Save checkpoints + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process and save_checkpoint: + print("Saving..") + state = { + "net": {key: self.model[key].state_dict() for key in self.model}, + "optimizer": self.optimizer.state_dict(), + "scheduler": self.optimizer.scheduler_state_dict(), + "iters": self.step, + "epoch": self.epoch, + } + save_path = os.path.join( + self.checkpoint_dir, + "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters), + ) + torch.save(state, save_path) + json.dump( + self.checkpoints_path, + open(os.path.join(self.checkpoint_dir, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + + self.accelerator.wait_for_everyone() + + self.epoch += 1 + + # Finish training + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}".format( + self.epoch, + self.step, + ), + ) + print("Saving..") + state = { + "net": {key: self.model[key].state_dict() for key in self.model}, + "optimizer": self.optimizer.state_dict(), + "scheduler": self.optimizer.scheduler_state_dict(), + "iters": self.step, + "epoch": self.epoch, + } + save_path = os.path.join( + self.checkpoint_dir, + "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.iters), + ) + torch.save(state, save_path) + + def _train_epoch(self): + """Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + _ = [self.model[key].train() for key in self.model] + + epoch_losses: dict = {} + epoch_total_loss: int = 0 + + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Get losses + total_loss, losses = self._train_step(batch) + self.batch_count += 1 + + # Log info + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + self.accelerator.log( + { + "Step/Learning Rate": ( + self.optimizer.schedulers["encoder"].get_last_lr()[0] + if self.step != 0 + else 0 + ) + }, + step=self.step, + ) + for key, _ in losses.items(): + self.accelerator.log( + { + "Step/Train {} Loss".format(key): losses[key], + }, + step=self.step, + ) + + if not epoch_losses: + epoch_losses = losses + else: + for key, value in losses.items(): + epoch_losses[key] += value + epoch_total_loss += total_loss + self.step += 1 + + # Get and log total losses + self.accelerator.wait_for_everyone() + epoch_total_loss = ( + epoch_total_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + for key in epoch_losses.keys(): + epoch_losses[key] = ( + epoch_losses[key] + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + return epoch_total_loss, epoch_losses + + def _train_step(self, data): + """Training forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_train_epoch`` for usage. + """ + # Init losses + train_losses = {} + total_loss = 0 + + # Use input feature to get predictions + data = [b.to(self.accelerator.device, non_blocking=True) for b in data] + waves, mels, wave_lengths, mel_input_length = data + + # extract semantic latent with w2v model + waves_16k = torchaudio.functional.resample(waves, 24000, 16000) + w2v_input = self.w2v_processor( + waves_16k, sampling_rate=16000, return_tensors="pt" + ).input_values.to(self.accelerator.device) + with torch.no_grad(): + w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits + predicted_ids = torch.argmax(w2v_outputs, dim=-1) + phone_ids = ( + F.interpolate( + predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest" + ) + .long() + .squeeze(0) + ) + + # get clips + mel_seg_len = min( + [int(mel_input_length.min().item()), self.cfg.train.max_frame_len] + ) + + gt_mel_seg = [] + wav_seg = [] + w2v_seg = [] + + for bib in range(len(mel_input_length)): + mel_length = int(mel_input_length[bib].item()) + + random_start = ( + np.random.randint(0, mel_length - mel_seg_len) + if mel_length != mel_seg_len + else 0 + ) + gt_mel_seg.append(mels[bib, :, random_start : random_start + mel_seg_len]) + + # w2v_seg.append(w2v_latent[bib, :, random_start:random_start + mel_seg_len]) + w2v_seg.append(phone_ids[bib, random_start : random_start + mel_seg_len]) + + y = waves[bib][random_start * 300 : (random_start + mel_seg_len) * 300] + + wav_seg.append(y.to(self.accelerator.device)) + + gt_mel_seg = torch.stack(gt_mel_seg).detach() + + wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1) + w2v_seg = torch.stack(w2v_seg).float().detach() + + with torch.no_grad(): + real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach() + F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1)) + + # normalize f0 + # Remove unvoiced frames (replace with -1) + gt_glob_f0s = [] + f0_targets = [] + for bib in range(len(F0_real)): + voiced_indices = F0_real[bib] > 5.0 + f0_voiced = F0_real[bib][voiced_indices] + + if len(f0_voiced) != 0: + # Convert to log scale + log_f0 = f0_voiced.log2() + + # Calculate mean and standard deviation + mean_f0 = log_f0.mean() + std_f0 = log_f0.std() + + # Normalize the F0 sequence + normalized_f0 = (log_f0 - mean_f0) / std_f0 + + # Create the normalized F0 sequence with unvoiced frames + normalized_sequence = torch.zeros_like(F0_real[bib]) + normalized_sequence[voiced_indices] = normalized_f0 + normalized_sequence[~voiced_indices] = ( + -10 + ) # Assign -10 to unvoiced frames + + gt_glob_f0s.append(mean_f0) + else: + normalized_sequence = torch.zeros_like(F0_real[bib]) - 10.0 + gt_glob_f0s.append(torch.tensor(0.0).to(self.accelerator.device)) + + # f0_targets.append(normalized_sequence[single_side_context // 200:-single_side_context // 200]) + f0_targets.append(normalized_sequence) + f0_targets = torch.stack(f0_targets).to(self.accelerator.device) + # fill nan with -10 + f0_targets[torch.isnan(f0_targets)] = -10.0 + # fill inf with -10 + f0_targets[torch.isinf(f0_targets)] = -10.0 + # if frame_rate not equal to 80, interpolate f0 from frame rate of 80 to target frame rate + if self.cfg.preprocess_params.frame_rate != 80: + f0_targets = F.interpolate( + f0_targets.unsqueeze(1), + mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate, + mode="nearest", + ).squeeze(1) + w2v_seg = F.interpolate( + w2v_seg, + mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate, + mode="nearest", + ) + + wav_seg_input = wav_seg + wav_seg_target = wav_seg + + z = self.model.encoder(wav_seg_input) + z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer( + z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths + ) + preds, rev_preds = self.model.fa_predictors(quantized, timbre) + + pred_wave = self.model.decoder(z) + + len_diff = wav_seg_target.size(-1) - pred_wave.size(-1) + if len_diff > 0: + wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2] + + # discriminator loss + d_fake = self.model.discriminator(pred_wave.detach()) + d_real = self.model.discriminator(wav_seg_target) + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + + self.optimizer.zero_grad() + self.accelerator.backward(loss_d) + grad_norm_d = torch.nn.utils.clip_grad_norm_( + self.model.discriminator.parameters(), 10.0 + ) + self.optimizer.step("discriminator") + self.optimizer.scheduler(key="discriminator") + + # generator loss + signal = AudioSignal(wav_seg_target, sample_rate=24000) + recons = AudioSignal(pred_wave, sample_rate=24000) + stft_loss = self.criterions["stft"](recons, signal) + mel_loss = self.criterions["mel"](recons, signal) + waveform_loss = self.criterions["l1"](recons, signal) + + d_fake = self.model.discriminator(pred_wave) + d_real = self.model.discriminator(wav_seg_target) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + + pred_f0, pred_uv = preds["f0"], preds["uv"] + rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"] + + common_min_size = min(pred_f0.size(-2), f0_targets.size(-1)) + f0_targets = f0_targets[..., :common_min_size] + real_norm = real_norm[..., :common_min_size] + + f0_loss = F.smooth_l1_loss( + f0_targets, pred_f0.squeeze(-1)[..., :common_min_size] + ) + uv_loss = F.smooth_l1_loss( + real_norm, pred_uv.squeeze(-1)[..., :common_min_size] + ) + rev_f0_loss = ( + F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size]) + if rev_pred_f0 is not None + else torch.FloatTensor([0]).to(self.accelerator.device) + ) + rev_uv_loss = ( + F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size]) + if rev_pred_uv is not None + else torch.FloatTensor([0]).to(self.accelerator.device) + ) + + tot_f0_loss = f0_loss + rev_f0_loss + tot_uv_loss = uv_loss + rev_uv_loss + + pred_content = preds["content"] + rev_pred_content = rev_preds["rev_content"] + + target_content_latents = w2v_seg[..., :common_min_size] + + content_loss = self.criterions["content"]( + pred_content.transpose(1, 2)[..., :common_min_size], + target_content_latents.long(), + ) + rev_content_loss = ( + self.criterions["content"]( + rev_pred_content.transpose(1, 2)[..., :common_min_size], + target_content_latents.long(), + ) + if rev_pred_content is not None + else torch.FloatTensor([0]).to(self.accelerator.device) + ) + + tot_content_loss = content_loss + rev_content_loss + + if self.speaker_model is not None: + spk_logits = torch.cat( + [ + self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1] + for w16, wl in zip(waves_16k, wave_lengths) + ], + dim=0, + ) + spk_labels = spk_logits.argmax(dim=-1) + else: + spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to( + self.accelerator.device + ) + + spk_pred_logits = preds["timbre"] + spk_loss = F.cross_entropy(spk_pred_logits, spk_labels) + x_spk_pred_logits = rev_preds["x_timbre"] + + x_spk_loss = ( + F.cross_entropy(x_spk_pred_logits, spk_labels) + if x_spk_pred_logits is not None + else torch.FloatTensor([0]).to(self.accelerator.device) + ) + + tot_spk_loss = spk_loss + x_spk_loss + + loss_gen_all = ( + mel_loss * 15.0 + + loss_feature * 1.0 + + loss_g * 1.0 + + commitment_loss * 0.25 + + codebook_loss * 1.0 + + tot_f0_loss * 1.0 + + tot_uv_loss * 1.0 + + tot_content_loss * 5.0 + + tot_spk_loss * 5.0 + ) + + self.optimizer.zero_grad() + self.accelerator.backward(loss_gen_all) + + with torch.no_grad(): + total_loss = loss_gen_all.item() + train_losses["stft"] = stft_loss.item() + train_losses["mel"] = mel_loss.item() + train_losses["l1"] = waveform_loss.item() + train_losses["f0"] = f0_loss.item() + train_losses["uv"] = uv_loss.item() + train_losses["content"] = content_loss.item() + train_losses["speaker"] = spk_loss.item() + train_losses["rev_f0"] = rev_f0_loss.item() + train_losses["rev_uv"] = rev_uv_loss.item() + train_losses["rev_content"] = rev_content_loss.item() + train_losses["rev_speaker"] = x_spk_loss.item() + + train_losses["feature"] = loss_feature.item() + train_losses["generator"] = loss_g.item() + train_losses["commitment"] = commitment_loss.item() + train_losses["codebook"] = codebook_loss.item() + + # discriminators + train_losses["discriminator"] = loss_d.item() + + return total_loss, train_losses + + def _inference(self, eval_wave): + """Inference during training for test audios.""" + z = self.model.encoder( + eval_wave[None, None, ...].to(self.accelerator.device).float() + ) + z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer( + z, eval_wave[None, None, ...], n_c=self.cfg.model_params.n_c_codebooks + ) + full_pred_wave = self.model.decoder(z) + return full_pred_wave[0] + + def _load_model(self, checkpoint_path=None, resume_type="resume"): + """Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. + """ + if resume_type == "resume": + if checkpoint_path is None: + available_checkpoints = glob.glob( + os.path.join(self.checkpoint_dir, "FAcodc_epoch_*_step_*.pth") + ) + # find the checkpoint that has the highest step number + latest_checkpoint = max( + available_checkpoints, + key=lambda x: int(x.split("_")[-1].split(".")[0]), + ) + earliest_checkpoint = min( + available_checkpoints, + key=lambda x: int(x.split("_")[-1].split(".")[0]), + ) + # delete the earliest checkpoint + if ( + earliest_checkpoint != latest_checkpoint + and self.accelerator.is_main_process + and len(available_checkpoints) > 4 + ): + os.remove(earliest_checkpoint) + print(f"Removed {earliest_checkpoint}") + else: + latest_checkpoint = checkpoint_path + + self.model, self.optimizer, self.epoch, self.step = load_checkpoint( + self.model, + self.optimizer, + latest_checkpoint, + load_only_params=False, + ignore_modules=[], + is_distributed=self.accelerator.num_processes > 1, + ) + + else: + raise ValueError("Invalid resume type") + return checkpoint_path + + def _count_parameters(self): + total_num = sum( + sum(p.numel() for p in self.model[key].parameters()) for key in self.model + ) + # trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + return total_num diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py b/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py @@ -0,0 +1 @@ + diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py b/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py new file mode 100644 index 0000000000000000000000000000000000000000..601ec960795c76be84417bb4e466ac7fe7754cb3 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py @@ -0,0 +1,219 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py + +""" +Implementation of model from: +Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using +Convolutional Recurrent Neural Networks" (2019) +Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d +""" +import torch +from torch import nn + + +class JDCNet(nn.Module): + """ + Joint Detection and Classification Network model for singing voice melody. + """ + + def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01): + super().__init__() + self.num_class = num_class + + # input = (b, 1, 31, 513), b = batch size + self.conv_block = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False + ), # out: (b, 64, 31, 513) + nn.BatchNorm2d(num_features=64), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513) + ) + + # res blocks + self.res_block1 = ResBlock( + in_channels=64, out_channels=128 + ) # (b, 128, 31, 128) + self.res_block2 = ResBlock( + in_channels=128, out_channels=192 + ) # (b, 192, 31, 32) + self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8) + + # pool block + self.pool_block = nn.Sequential( + nn.BatchNorm2d(num_features=256), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2) + nn.Dropout(p=0.2), + ) + + # maxpool layers (for auxiliary network inputs) + # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2) + self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40)) + # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2) + self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20)) + # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2) + self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10)) + + # in = (b, 640, 31, 2), out = (b, 256, 31, 2) + self.detector_conv = nn.Sequential( + nn.Conv2d(640, 256, 1, bias=False), + nn.BatchNorm2d(256), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.Dropout(p=0.2), + ) + + # input: (b, 31, 512) - resized from (b, 256, 31, 2) + self.bilstm_classifier = nn.LSTM( + input_size=512, hidden_size=256, batch_first=True, bidirectional=True + ) # (b, 31, 512) + + # input: (b, 31, 512) - resized from (b, 256, 31, 2) + self.bilstm_detector = nn.LSTM( + input_size=512, hidden_size=256, batch_first=True, bidirectional=True + ) # (b, 31, 512) + + # input: (b * 31, 512) + self.classifier = nn.Linear( + in_features=512, out_features=self.num_class + ) # (b * 31, num_class) + + # input: (b * 31, 512) + self.detector = nn.Linear( + in_features=512, out_features=2 + ) # (b * 31, 2) - binary classifier + + # initialize weights + self.apply(self.init_weights) + + def get_feature_GAN(self, x): + seq_len = x.shape[-2] + x = x.float().transpose(-1, -2) + + convblock_out = self.conv_block(x) + + resblock1_out = self.res_block1(convblock_out) + resblock2_out = self.res_block2(resblock1_out) + resblock3_out = self.res_block3(resblock2_out) + poolblock_out = self.pool_block[0](resblock3_out) + poolblock_out = self.pool_block[1](poolblock_out) + + return poolblock_out.transpose(-1, -2) + + def get_feature(self, x): + seq_len = x.shape[-2] + x = x.float().transpose(-1, -2) + + convblock_out = self.conv_block(x) + + resblock1_out = self.res_block1(convblock_out) + resblock2_out = self.res_block2(resblock1_out) + resblock3_out = self.res_block3(resblock2_out) + poolblock_out = self.pool_block[0](resblock3_out) + poolblock_out = self.pool_block[1](poolblock_out) + + return self.pool_block[2](poolblock_out) + + def forward(self, x): + """ + Returns: + classification_prediction, detection_prediction + sizes: (b, 31, 722), (b, 31, 2) + """ + ############################### + # forward pass for classifier # + ############################### + seq_len = x.shape[-1] + x = x.float().transpose(-1, -2) + + convblock_out = self.conv_block(x) + + resblock1_out = self.res_block1(convblock_out) + resblock2_out = self.res_block2(resblock1_out) + resblock3_out = self.res_block3(resblock2_out) + + poolblock_out = self.pool_block[0](resblock3_out) + poolblock_out = self.pool_block[1](poolblock_out) + GAN_feature = poolblock_out.transpose(-1, -2) + poolblock_out = self.pool_block[2](poolblock_out) + + # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512) + classifier_out = ( + poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512)) + ) + classifier_out, _ = self.bilstm_classifier( + classifier_out + ) # ignore the hidden states + + classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512) + classifier_out = self.classifier(classifier_out) + classifier_out = classifier_out.view( + (-1, seq_len, self.num_class) + ) # (b, 31, num_class) + + # sizes: (b, 31, 722), (b, 31, 2) + # classifier output consists of predicted pitch classes per frame + # detector output consists of: (isvoice, notvoice) estimates per frame + return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out + + @staticmethod + def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell): + for p in m.parameters(): + if p.data is None: + continue + + if len(p.shape) >= 2: + nn.init.orthogonal_(p.data) + else: + nn.init.normal_(p.data) + + +class ResBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01): + super().__init__() + self.downsample = in_channels != out_channels + + # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper + self.pre_conv = nn.Sequential( + nn.BatchNorm2d(num_features=in_channels), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only + ) + + # conv layers + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.BatchNorm2d(out_channels), + nn.LeakyReLU(leaky_relu_slope, inplace=True), + nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), + ) + + # 1 x 1 convolution layer to match the feature dimensions + self.conv1by1 = None + if self.downsample: + self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False) + + def forward(self, x): + x = self.pre_conv(x) + if self.downsample: + x = self.conv(x) + self.conv1by1(x) + else: + x = self.conv(x) + x + return x diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py b/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..c29854fd97cefc66125301003198b2da6ea1e9be --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py @@ -0,0 +1,437 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py + +import copy +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from . import commons + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + window_size=4, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + **kwargs + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + p_dropout=0.0, + window_size=None, + heads_share=True, + block_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_channels, + kernel_size, + p_dropout=0.0, + activation=None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/commons.py b/indextts/utils/maskgct/models/codec/facodec/modules/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..89baaf4b06426595b7be1ab9ca4d94c5c99779d6 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/commons.py @@ -0,0 +1,331 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import math +import os.path + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from munch import Munch +import json + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def slice_segments_audio(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to( + dtype=torch.long + ) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + + +def log_norm(x, mean=-4, std=4, dim=2): + """ + normalized log mel -> mel -> norm -> log(norm) + """ + x = torch.log(torch.exp(x * std + mean).norm(dim=dim)) + return x + + +from huggingface_hub import hf_hub_download + + +def load_F0_models(path): + # load F0 model + from .JDC.model import JDCNet + + F0_model = JDCNet(num_class=1, seq_len=192) + if not os.path.exists(path): + path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7") + params = torch.load(path, map_location="cpu")["net"] + F0_model.load_state_dict(params) + _ = F0_model.train() + + return F0_model + + +# Generators +from modules.dac.model.dac import Encoder, Decoder +from .quantize import FAquantizer, FApredictors + +# Discriminators +from modules.dac.model.discriminator import Discriminator + + +def build_model(args): + encoder = Encoder( + d_model=args.DAC.encoder_dim, + strides=args.DAC.encoder_rates, + d_latent=1024, + causal=args.causal, + lstm=args.lstm, + ) + + quantizer = FAquantizer( + in_dim=1024, + n_p_codebooks=1, + n_c_codebooks=args.n_c_codebooks, + n_t_codebooks=2, + n_r_codebooks=3, + codebook_size=1024, + codebook_dim=8, + quantizer_dropout=0.5, + causal=args.causal, + separate_prosody_encoder=args.separate_prosody_encoder, + timbre_norm=args.timbre_norm, + ) + + fa_predictors = FApredictors( + in_dim=1024, + use_gr_content_f0=args.use_gr_content_f0, + use_gr_prosody_phone=args.use_gr_prosody_phone, + use_gr_residual_f0=True, + use_gr_residual_phone=True, + use_gr_timbre_content=True, + use_gr_timbre_prosody=args.use_gr_timbre_prosody, + use_gr_x_timbre=True, + norm_f0=args.norm_f0, + timbre_norm=args.timbre_norm, + use_gr_content_global_f0=args.use_gr_content_global_f0, + ) + + decoder = Decoder( + input_channel=1024, + channels=args.DAC.decoder_dim, + rates=args.DAC.decoder_rates, + causal=args.causal, + lstm=args.lstm, + ) + + discriminator = Discriminator( + rates=[], + periods=[2, 3, 5, 7, 11], + fft_sizes=[2048, 1024, 512], + sample_rate=args.DAC.sr, + bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)], + ) + + nets = Munch( + encoder=encoder, + quantizer=quantizer, + decoder=decoder, + discriminator=discriminator, + fa_predictors=fa_predictors, + ) + + return nets + + +def load_checkpoint( + model, + optimizer, + path, + load_only_params=True, + ignore_modules=[], + is_distributed=False, +): + state = torch.load(path, map_location="cpu") + params = state["net"] + for key in model: + if key in params and key not in ignore_modules: + if not is_distributed: + # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix + for k in list(params[key].keys()): + if k.startswith("module."): + params[key][k[len("module.") :]] = params[key][k] + del params[key][k] + print("%s loaded" % key) + model[key].load_state_dict(params[key], strict=True) + _ = [model[key].eval() for key in model] + + if not load_only_params: + epoch = state["epoch"] + 1 + iters = state["iters"] + optimizer.load_state_dict(state["optimizer"]) + optimizer.load_scheduler_state_dict(state["scheduler"]) + + else: + epoch = state["epoch"] + 1 + iters = state["iters"] + + return model, optimizer, epoch, iters + + +def recursive_munch(d): + if isinstance(d, dict): + return Munch((k, recursive_munch(v)) for k, v in d.items()) + elif isinstance(d, list): + return [recursive_munch(v) for v in d] + else: + return d diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py b/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py new file mode 100644 index 0000000000000000000000000000000000000000..d09396ea20c653b2a443e144ab429f534ce033fd --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch.autograd import Function +import torch +from torch import nn + + +class GradientReversal(Function): + @staticmethod + def forward(ctx, x, alpha): + ctx.save_for_backward(x, alpha) + return x + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + _, alpha = ctx.saved_tensors + if ctx.needs_input_grad[0]: + grad_input = -alpha * grad_output + return grad_input, None + + +revgrad = GradientReversal.apply + + +class GradientReversal(nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = torch.tensor(alpha, requires_grad=False) + + def forward(self, x): + return revgrad(x, self.alpha) diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/layers.py b/indextts/utils/maskgct/models/codec/facodec/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..186cbe7bef653fe7cbd6dbd40f38624457b8ecfa --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/layers.py @@ -0,0 +1,460 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch +from torch import nn +from typing import Optional, Any +from torch import Tensor +import torch.nn.functional as F +import torchaudio +import torchaudio.functional as audio_F + +import random + +random.seed(0) + + +def _get_activation_fn(activ): + if activ == "relu": + return nn.ReLU() + elif activ == "lrelu": + return nn.LeakyReLU(0.2) + elif activ == "swish": + return lambda x: x * torch.sigmoid(x) + else: + raise RuntimeError( + "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ + ) + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain) + ) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + w_init_gain="linear", + param=None, + ): + super(ConvNorm, self).__init__() + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + + torch.nn.init.xavier_uniform_( + self.conv.weight, + gain=torch.nn.init.calculate_gain(w_init_gain, param=param), + ) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + + +class CausualConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=1, + dilation=1, + bias=True, + w_init_gain="linear", + param=None, + ): + super(CausualConv, self).__init__() + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) * 2 + else: + self.padding = padding * 2 + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + dilation=dilation, + bias=bias, + ) + + torch.nn.init.xavier_uniform_( + self.conv.weight, + gain=torch.nn.init.calculate_gain(w_init_gain, param=param), + ) + + def forward(self, x): + x = self.conv(x) + x = x[:, :, : -self.padding] + return x + + +class CausualBlock(nn.Module): + def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"): + super(CausualBlock, self).__init__() + self.blocks = nn.ModuleList( + [ + self._get_conv( + hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p + ) + for i in range(n_conv) + ] + ) + + def forward(self, x): + for block in self.blocks: + res = x + x = block(x) + x += res + return x + + def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2): + layers = [ + CausualConv( + hidden_dim, + hidden_dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + ), + _get_activation_fn(activ), + nn.BatchNorm1d(hidden_dim), + nn.Dropout(p=dropout_p), + CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), + _get_activation_fn(activ), + nn.Dropout(p=dropout_p), + ] + return nn.Sequential(*layers) + + +class ConvBlock(nn.Module): + def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"): + super().__init__() + self._n_groups = 8 + self.blocks = nn.ModuleList( + [ + self._get_conv( + hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p + ) + for i in range(n_conv) + ] + ) + + def forward(self, x): + for block in self.blocks: + res = x + x = block(x) + x += res + return x + + def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2): + layers = [ + ConvNorm( + hidden_dim, + hidden_dim, + kernel_size=3, + padding=dilation, + dilation=dilation, + ), + _get_activation_fn(activ), + nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim), + nn.Dropout(p=dropout_p), + ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), + _get_activation_fn(activ), + nn.Dropout(p=dropout_p), + ] + return nn.Sequential(*layers) + + +class LocationLayer(nn.Module): + def __init__(self, attention_n_filters, attention_kernel_size, attention_dim): + super(LocationLayer, self).__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = ConvNorm( + 2, + attention_n_filters, + kernel_size=attention_kernel_size, + padding=padding, + bias=False, + stride=1, + dilation=1, + ) + self.location_dense = LinearNorm( + attention_n_filters, attention_dim, bias=False, w_init_gain="tanh" + ) + + def forward(self, attention_weights_cat): + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class Attention(nn.Module): + def __init__( + self, + attention_rnn_dim, + embedding_dim, + attention_dim, + attention_location_n_filters, + attention_location_kernel_size, + ): + super(Attention, self).__init__() + self.query_layer = LinearNorm( + attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.memory_layer = LinearNorm( + embedding_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer( + attention_location_n_filters, attention_location_kernel_size, attention_dim + ) + self.score_mask_value = -float("inf") + + def get_alignment_energies(self, query, processed_memory, attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v( + torch.tanh(processed_query + processed_attention_weights + processed_memory) + ) + + energies = energies.squeeze(-1) + return energies + + def forward( + self, + attention_hidden_state, + memory, + processed_memory, + attention_weights_cat, + mask, + ): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + alignment = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat + ) + + if mask is not None: + alignment.data.masked_fill_(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class ForwardAttentionV2(nn.Module): + def __init__( + self, + attention_rnn_dim, + embedding_dim, + attention_dim, + attention_location_n_filters, + attention_location_kernel_size, + ): + super(ForwardAttentionV2, self).__init__() + self.query_layer = LinearNorm( + attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.memory_layer = LinearNorm( + embedding_dim, attention_dim, bias=False, w_init_gain="tanh" + ) + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer( + attention_location_n_filters, attention_location_kernel_size, attention_dim + ) + self.score_mask_value = -float(1e20) + + def get_alignment_energies(self, query, processed_memory, attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: prev. and cumulative att weights (B, 2, max_time) + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v( + torch.tanh(processed_query + processed_attention_weights + processed_memory) + ) + + energies = energies.squeeze(-1) + return energies + + def forward( + self, + attention_hidden_state, + memory, + processed_memory, + attention_weights_cat, + mask, + log_alpha, + ): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + log_energy = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat + ) + + # log_energy = + + if mask is not None: + log_energy.data.masked_fill_(mask, self.score_mask_value) + + # attention_weights = F.softmax(alignment, dim=1) + + # content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] + # log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] + + # log_total_score = log_alpha + content_score + + # previous_attention_weights = attention_weights_cat[:,0,:] + + log_alpha_shift_padded = [] + max_time = log_energy.size(1) + for sft in range(2): + shifted = log_alpha[:, : max_time - sft] + shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value) + log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) + + biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2) + + log_alpha_new = biased + log_energy + + attention_weights = F.softmax(log_alpha_new, dim=1) + + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights, log_alpha_new + + +class PhaseShuffle2d(nn.Module): + def __init__(self, n=2): + super(PhaseShuffle2d, self).__init__() + self.n = n + self.random = random.Random(1) + + def forward(self, x, move=None): + # x.size = (B, C, M, L) + if move is None: + move = self.random.randint(-self.n, self.n) + + if move == 0: + return x + else: + left = x[:, :, :, :move] + right = x[:, :, :, move:] + shuffled = torch.cat([right, left], dim=3) + return shuffled + + +class PhaseShuffle1d(nn.Module): + def __init__(self, n=2): + super(PhaseShuffle1d, self).__init__() + self.n = n + self.random = random.Random(1) + + def forward(self, x, move=None): + # x.size = (B, C, M, L) + if move is None: + move = self.random.randint(-self.n, self.n) + + if move == 0: + return x + else: + left = x[:, :, :move] + right = x[:, :, move:] + shuffled = torch.cat([right, left], dim=2) + + return shuffled + + +class MFCC(nn.Module): + def __init__(self, n_mfcc=40, n_mels=80): + super(MFCC, self).__init__() + self.n_mfcc = n_mfcc + self.n_mels = n_mels + self.norm = "ortho" + dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm) + self.register_buffer("dct_mat", dct_mat) + + def forward(self, mel_specgram): + if len(mel_specgram.shape) == 2: + mel_specgram = mel_specgram.unsqueeze(0) + unsqueezed = True + else: + unsqueezed = False + # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) + # -> (channel, time, n_mfcc).tranpose(...) + mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) + + # unpack batch + if unsqueezed: + mfcc = mfcc.squeeze(0) + return mfcc diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py b/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..d9cb55be173ff1aa7a15cccab7c7ac51fbd042c1 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py @@ -0,0 +1,741 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from modules.dac.nn.quantize import ResidualVectorQuantize +from torch import nn +from .wavenet import WN +from .style_encoder import StyleEncoder +from .gradient_reversal import GradientReversal +import torch +import torchaudio +import torchaudio.functional as audio_F +import numpy as np +from ..alias_free_torch import * +from torch.nn.utils import weight_norm +from torch import nn, sin, pow +from einops.layers.torch import Rearrange +from modules.dac.model.encodec import SConv1d + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False + ): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta := x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + return x + self.block(x) + + +class CNNLSTM(nn.Module): + def __init__(self, indim, outdim, head, global_pred=False): + super().__init__() + self.global_pred = global_pred + self.model = nn.Sequential( + ResidualUnit(indim, dilation=1), + ResidualUnit(indim, dilation=2), + ResidualUnit(indim, dilation=3), + Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)), + Rearrange("b c t -> b t c"), + ) + self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)]) + + def forward(self, x): + # x: [B, C, T] + x = self.model(x) + if self.global_pred: + x = torch.mean(x, dim=1, keepdim=False) + outs = [head(x) for head in self.heads] + return outs + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +class MFCC(nn.Module): + def __init__(self, n_mfcc=40, n_mels=80): + super(MFCC, self).__init__() + self.n_mfcc = n_mfcc + self.n_mels = n_mels + self.norm = "ortho" + dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm) + self.register_buffer("dct_mat", dct_mat) + + def forward(self, mel_specgram): + if len(mel_specgram.shape) == 2: + mel_specgram = mel_specgram.unsqueeze(0) + unsqueezed = True + else: + unsqueezed = False + # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) + # -> (channel, time, n_mfcc).tranpose(...) + mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) + + # unpack batch + if unsqueezed: + mfcc = mfcc.squeeze(0) + return mfcc + + +class FAquantizer(nn.Module): + def __init__( + self, + in_dim=1024, + n_p_codebooks=1, + n_c_codebooks=2, + n_t_codebooks=2, + n_r_codebooks=3, + codebook_size=1024, + codebook_dim=8, + quantizer_dropout=0.5, + causal=False, + separate_prosody_encoder=False, + timbre_norm=False, + ): + super(FAquantizer, self).__init__() + conv1d_type = SConv1d # if causal else nn.Conv1d + self.prosody_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_p_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.content_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_c_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + if not timbre_norm: + self.timbre_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_t_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + else: + self.timbre_encoder = StyleEncoder( + in_dim=80, hidden_dim=512, out_dim=in_dim + ) + self.timbre_linear = nn.Linear(1024, 1024 * 2) + self.timbre_linear.bias.data[:1024] = 1 + self.timbre_linear.bias.data[1024:] = 0 + self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False) + + self.residual_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_r_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + if separate_prosody_encoder: + self.melspec_linear = conv1d_type( + in_channels=20, out_channels=256, kernel_size=1, causal=causal + ) + self.melspec_encoder = WN( + hidden_channels=256, + kernel_size=5, + dilation_rate=1, + n_layers=8, + gin_channels=0, + p_dropout=0.2, + causal=causal, + ) + self.melspec_linear2 = conv1d_type( + in_channels=256, out_channels=1024, kernel_size=1, causal=causal + ) + else: + pass + self.separate_prosody_encoder = separate_prosody_encoder + + self.prob_random_mask_residual = 0.75 + + SPECT_PARAMS = { + "n_fft": 2048, + "win_length": 1200, + "hop_length": 300, + } + MEL_PARAMS = { + "n_mels": 80, + } + + self.to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS + ) + self.mel_mean, self.mel_std = -4, 4 + self.frame_rate = 24000 / 300 + self.hop_length = 300 + + self.is_timbre_norm = timbre_norm + if timbre_norm: + self.forward = self.forward_v2 + + def preprocess(self, wave_tensor, n_bins=20): + mel_tensor = self.to_mel(wave_tensor.squeeze(1)) + mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std + return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)] + + @torch.no_grad() + def decode(self, codes): + code_c, code_p, code_t = codes.split([1, 1, 2], dim=1) + + z_c = self.content_quantizer.from_codes(code_c)[0] + z_p = self.prosody_quantizer.from_codes(code_p)[0] + z_t = self.timbre_quantizer.from_codes(code_t)[0] + + z = z_c + z_p + z_t + + return z, [z_c, z_p, z_t] + + @torch.no_grad() + def encode(self, x, wave_segments, n_c=1): + outs = 0 + if self.separate_prosody_encoder: + prosody_feature = self.preprocess(wave_segments) + + f0_input = prosody_feature # (B, T, 20) + f0_input = self.melspec_linear(f0_input) + f0_input = self.melspec_encoder( + f0_input, + torch.ones(f0_input.shape[0], 1, f0_input.shape[2]) + .to(f0_input.device) + .bool(), + ) + f0_input = self.melspec_linear2(f0_input) + + common_min_size = min(f0_input.size(2), x.size(2)) + f0_input = f0_input[:, :, :common_min_size] + + x = x[:, :, :common_min_size] + + ( + z_p, + codes_p, + latents_p, + commitment_loss_p, + codebook_loss_p, + ) = self.prosody_quantizer(f0_input, 1) + outs += z_p.detach() + else: + ( + z_p, + codes_p, + latents_p, + commitment_loss_p, + codebook_loss_p, + ) = self.prosody_quantizer(x, 1) + outs += z_p.detach() + + ( + z_c, + codes_c, + latents_c, + commitment_loss_c, + codebook_loss_c, + ) = self.content_quantizer(x, n_c) + outs += z_c.detach() + + timbre_residual_feature = x - z_p.detach() - z_c.detach() + + ( + z_t, + codes_t, + latents_t, + commitment_loss_t, + codebook_loss_t, + ) = self.timbre_quantizer(timbre_residual_feature, 2) + outs += z_t # we should not detach timbre + + residual_feature = timbre_residual_feature - z_t + + ( + z_r, + codes_r, + latents_r, + commitment_loss_r, + codebook_loss_r, + ) = self.residual_quantizer(residual_feature, 3) + + return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r] + + def forward( + self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2 + ): + # timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1)) + # timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device)) + outs = 0 + if self.separate_prosody_encoder: + prosody_feature = self.preprocess(wave_segments) + + f0_input = prosody_feature # (B, T, 20) + f0_input = self.melspec_linear(f0_input) + f0_input = self.melspec_encoder( + f0_input, + torch.ones(f0_input.shape[0], 1, f0_input.shape[2]) + .to(f0_input.device) + .bool(), + ) + f0_input = self.melspec_linear2(f0_input) + + common_min_size = min(f0_input.size(2), x.size(2)) + f0_input = f0_input[:, :, :common_min_size] + + x = x[:, :, :common_min_size] + + ( + z_p, + codes_p, + latents_p, + commitment_loss_p, + codebook_loss_p, + ) = self.prosody_quantizer(f0_input, 1) + outs += z_p.detach() + else: + ( + z_p, + codes_p, + latents_p, + commitment_loss_p, + codebook_loss_p, + ) = self.prosody_quantizer(x, 1) + outs += z_p.detach() + + ( + z_c, + codes_c, + latents_c, + commitment_loss_c, + codebook_loss_c, + ) = self.content_quantizer(x, n_c) + outs += z_c.detach() + + timbre_residual_feature = x - z_p.detach() - z_c.detach() + + ( + z_t, + codes_t, + latents_t, + commitment_loss_t, + codebook_loss_t, + ) = self.timbre_quantizer(timbre_residual_feature, n_t) + outs += z_t # we should not detach timbre + + residual_feature = timbre_residual_feature - z_t + + ( + z_r, + codes_r, + latents_r, + commitment_loss_r, + codebook_loss_r, + ) = self.residual_quantizer(residual_feature, 3) + + bsz = z_r.shape[0] + res_mask = np.random.choice( + [0, 1], + size=bsz, + p=[ + self.prob_random_mask_residual, + 1 - self.prob_random_mask_residual, + ], + ) + res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1) + res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype) + noise_must_on = noise_added_flags * recon_noisy_flags + noise_must_off = noise_added_flags * (~recon_noisy_flags) + res_mask[noise_must_on] = 1 + res_mask[noise_must_off] = 0 + + outs += z_r * res_mask + + quantized = [z_p, z_c, z_t, z_r] + commitment_losses = ( + commitment_loss_p + + commitment_loss_c + + commitment_loss_t + + commitment_loss_r + ) + codebook_losses = ( + codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r + ) + + return outs, quantized, commitment_losses, codebook_losses + + def forward_v2( + self, + x, + wave_segments, + n_c=1, + n_t=2, + full_waves=None, + wave_lens=None, + return_codes=False, + ): + # timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1)) + if full_waves is None: + mel = self.preprocess(wave_segments, n_bins=80) + timbre = self.timbre_encoder( + mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device) + ) + else: + mel = self.preprocess(full_waves, n_bins=80) + timbre = self.timbre_encoder( + mel, + sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1), + ) + outs = 0 + if self.separate_prosody_encoder: + prosody_feature = self.preprocess(wave_segments) + + f0_input = prosody_feature # (B, T, 20) + f0_input = self.melspec_linear(f0_input) + f0_input = self.melspec_encoder( + f0_input, + torch.ones(f0_input.shape[0], 1, f0_input.shape[2]) + .to(f0_input.device) + .bool(), + ) + f0_input = self.melspec_linear2(f0_input) + + common_min_size = min(f0_input.size(2), x.size(2)) + f0_input = f0_input[:, :, :common_min_size] + + x = x[:, :, :common_min_size] + + ( + z_p, + codes_p, + latents_p, + commitment_loss_p, + codebook_loss_p, + ) = self.prosody_quantizer(f0_input, 1) + outs += z_p.detach() + else: + ( + z_p, + codes_p, + latents_p, + commitment_loss_p, + codebook_loss_p, + ) = self.prosody_quantizer(x, 1) + outs += z_p.detach() + + ( + z_c, + codes_c, + latents_c, + commitment_loss_c, + codebook_loss_c, + ) = self.content_quantizer(x, n_c) + outs += z_c.detach() + + residual_feature = x - z_p.detach() - z_c.detach() + + ( + z_r, + codes_r, + latents_r, + commitment_loss_r, + codebook_loss_r, + ) = self.residual_quantizer(residual_feature, 3) + + bsz = z_r.shape[0] + res_mask = np.random.choice( + [0, 1], + size=bsz, + p=[ + self.prob_random_mask_residual, + 1 - self.prob_random_mask_residual, + ], + ) + res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1) + res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype) + + if not self.training: + res_mask = torch.ones_like(res_mask) + outs += z_r * res_mask + + quantized = [z_p, z_c, z_r] + codes = [codes_p, codes_c, codes_r] + commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r + codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r + + style = self.timbre_linear(timbre).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + outs = outs.transpose(1, 2) + outs = self.timbre_norm(outs) + outs = outs.transpose(1, 2) + outs = outs * gamma + beta + + if return_codes: + return outs, quantized, commitment_losses, codebook_losses, timbre, codes + else: + return outs, quantized, commitment_losses, codebook_losses, timbre + + def voice_conversion(self, z, ref_wave): + ref_mel = self.preprocess(ref_wave, n_bins=80) + ref_timbre = self.timbre_encoder( + ref_mel, + sequence_mask( + torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length, + ref_mel.size(-1), + ).unsqueeze(1), + ) + style = self.timbre_linear(ref_timbre).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + outs = z.transpose(1, 2) + outs = self.timbre_norm(outs) + outs = outs.transpose(1, 2) + outs = outs * gamma + beta + + return outs + + +class FApredictors(nn.Module): + def __init__( + self, + in_dim=1024, + use_gr_content_f0=False, + use_gr_prosody_phone=False, + use_gr_residual_f0=False, + use_gr_residual_phone=False, + use_gr_timbre_content=True, + use_gr_timbre_prosody=True, + use_gr_x_timbre=False, + norm_f0=True, + timbre_norm=False, + use_gr_content_global_f0=False, + ): + super(FApredictors, self).__init__() + self.f0_predictor = CNNLSTM(in_dim, 1, 2) + self.phone_predictor = CNNLSTM(in_dim, 1024, 1) + if timbre_norm: + self.timbre_predictor = nn.Linear(in_dim, 20000) + else: + self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True) + + self.use_gr_content_f0 = use_gr_content_f0 + self.use_gr_prosody_phone = use_gr_prosody_phone + self.use_gr_residual_f0 = use_gr_residual_f0 + self.use_gr_residual_phone = use_gr_residual_phone + self.use_gr_timbre_content = use_gr_timbre_content + self.use_gr_timbre_prosody = use_gr_timbre_prosody + self.use_gr_x_timbre = use_gr_x_timbre + + self.rev_f0_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2) + ) + self.rev_content_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1) + ) + self.rev_timbre_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True) + ) + + self.norm_f0 = norm_f0 + self.timbre_norm = timbre_norm + if timbre_norm: + self.forward = self.forward_v2 + self.global_f0_predictor = nn.Linear(in_dim, 1) + + self.use_gr_content_global_f0 = use_gr_content_global_f0 + if use_gr_content_global_f0: + self.rev_global_f0_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True) + ) + + def forward(self, quantized): + prosody_latent = quantized[0] + content_latent = quantized[1] + timbre_latent = quantized[2] + residual_latent = quantized[3] + content_pred = self.phone_predictor(content_latent)[0] + + if self.norm_f0: + spk_pred = self.timbre_predictor(timbre_latent)[0] + f0_pred, uv_pred = self.f0_predictor(prosody_latent) + else: + spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0] + f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent) + + prosody_rev_latent = torch.zeros_like(quantized[0]) + if self.use_gr_content_f0: + prosody_rev_latent += quantized[1] + if self.use_gr_timbre_prosody: + prosody_rev_latent += quantized[2] + if self.use_gr_residual_f0: + prosody_rev_latent += quantized[3] + rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent) + + content_rev_latent = torch.zeros_like(quantized[1]) + if self.use_gr_prosody_phone: + content_rev_latent += quantized[0] + if self.use_gr_timbre_content: + content_rev_latent += quantized[2] + if self.use_gr_residual_phone: + content_rev_latent += quantized[3] + rev_content_pred = self.rev_content_predictor(content_rev_latent)[0] + + if self.norm_f0: + timbre_rev_latent = quantized[0] + quantized[1] + quantized[3] + else: + timbre_rev_latent = quantized[1] + quantized[3] + if self.use_gr_x_timbre: + x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0] + else: + x_spk_pred = None + + preds = { + "f0": f0_pred, + "uv": uv_pred, + "content": content_pred, + "timbre": spk_pred, + } + + rev_preds = { + "rev_f0": rev_f0_pred, + "rev_uv": rev_uv_pred, + "rev_content": rev_content_pred, + "x_timbre": x_spk_pred, + } + return preds, rev_preds + + def forward_v2(self, quantized, timbre): + prosody_latent = quantized[0] + content_latent = quantized[1] + residual_latent = quantized[2] + content_pred = self.phone_predictor(content_latent)[0] + + spk_pred = self.timbre_predictor(timbre) + f0_pred, uv_pred = self.f0_predictor(prosody_latent) + + prosody_rev_latent = torch.zeros_like(prosody_latent) + if self.use_gr_content_f0: + prosody_rev_latent += content_latent + if self.use_gr_residual_f0: + prosody_rev_latent += residual_latent + rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent) + + content_rev_latent = torch.zeros_like(content_latent) + if self.use_gr_prosody_phone: + content_rev_latent += prosody_latent + if self.use_gr_residual_phone: + content_rev_latent += residual_latent + rev_content_pred = self.rev_content_predictor(content_rev_latent)[0] + + timbre_rev_latent = prosody_latent + content_latent + residual_latent + if self.use_gr_x_timbre: + x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0] + else: + x_spk_pred = None + + preds = { + "f0": f0_pred, + "uv": uv_pred, + "content": content_pred, + "timbre": spk_pred, + } + + rev_preds = { + "rev_f0": rev_f0_pred, + "rev_uv": rev_uv_pred, + "rev_content": rev_content_pred, + "x_timbre": x_spk_pred, + } + return preds, rev_preds diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py b/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e437c1adfc823af8f2324a24c4801b130eb69191 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py + +from . import attentions +from torch import nn +import torch +from torch.nn import functional as F + + +class Mish(nn.Module): + def __init__(self): + super(Mish, self).__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class Conv1dGLU(nn.Module): + """ + Conv1d + GLU(Gated Linear Unit) with residual connection. + For GLU refer to https://arxiv.org/abs/1612.08083 paper. + """ + + def __init__(self, in_channels, out_channels, kernel_size, dropout): + super(Conv1dGLU, self).__init__() + self.out_channels = out_channels + self.conv1 = nn.Conv1d( + in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2 + ) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = self.conv1(x) + x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) + x = x1 * torch.sigmoid(x2) + x = residual + self.dropout(x) + return x + + +class StyleEncoder(torch.nn.Module): + def __init__(self, in_dim=513, hidden_dim=128, out_dim=256): + + super().__init__() + + self.in_dim = in_dim # Linear 513 wav2vec 2.0 1024 + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.kernel_size = 5 + self.n_head = 2 + self.dropout = 0.1 + + self.spectral = nn.Sequential( + nn.Conv1d(self.in_dim, self.hidden_dim, 1), + Mish(), + nn.Dropout(self.dropout), + nn.Conv1d(self.hidden_dim, self.hidden_dim, 1), + Mish(), + nn.Dropout(self.dropout), + ) + + self.temporal = nn.Sequential( + Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), + Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), + ) + + self.slf_attn = attentions.MultiHeadAttention( + self.hidden_dim, + self.hidden_dim, + self.n_head, + p_dropout=self.dropout, + proximal_bias=False, + proximal_init=True, + ) + self.atten_drop = nn.Dropout(self.dropout) + self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1) + + def forward(self, x, mask=None): + + # spectral + x = self.spectral(x) * mask + # temporal + x = self.temporal(x) * mask + + # self-attention + attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1) + y = self.slf_attn(x, x, attn_mask=attn_mask) + x = x + self.atten_drop(y) + + # fc + x = self.fc(x) + + # temoral average pooling + w = self.temporal_avg_pool(x, mask=mask) + + return w + + def temporal_avg_pool(self, x, mask=None): + if mask is None: + out = torch.mean(x, dim=2) + else: + len_ = mask.sum(dim=2) + x = x.sum(dim=2) + + out = torch.div(x, len_) + return out diff --git a/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py b/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2da541be26d54f0ccc098637334ad812bc2374 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py @@ -0,0 +1,224 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py + +import math +import torch +from torch import nn +from torch.nn import functional as F + +from modules.dac.model.encodec import SConv1d + +from . import commons + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append( + nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + causal=False, + ): + super(WN, self).__init__() + conv1d_type = SConv1d + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + self.cond_layer = conv1d_type( + gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm" + ) + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = conv1d_type( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + norm="weight_norm", + causal=causal, + ) + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = conv1d_type( + hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal + ) + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) diff --git a/indextts/utils/maskgct/models/codec/facodec/optimizer.py b/indextts/utils/maskgct/models/codec/facodec/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6d798d5f20a137a5140834d64407c423012673 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/facodec/optimizer.py @@ -0,0 +1,104 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os, sys +import os.path as osp +import numpy as np +import torch +from torch import nn +from torch.optim import Optimizer +from functools import reduce +from torch.optim import AdamW + + +class MultiOptimizer: + def __init__(self, optimizers={}, schedulers={}): + self.optimizers = optimizers + self.schedulers = schedulers + self.keys = list(optimizers.keys()) + self.param_groups = reduce( + lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()] + ) + + def state_dict(self): + state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys] + return state_dicts + + def scheduler_state_dict(self): + state_dicts = [(key, self.schedulers[key].state_dict()) for key in self.keys] + return state_dicts + + def load_state_dict(self, state_dict): + for key, val in state_dict: + try: + self.optimizers[key].load_state_dict(val) + except: + print("Unloaded %s" % key) + + def load_scheduler_state_dict(self, state_dict): + for key, val in state_dict: + try: + self.schedulers[key].load_state_dict(val) + except: + print("Unloaded %s" % key) + + def step(self, key=None, scaler=None): + keys = [key] if key is not None else self.keys + _ = [self._step(key, scaler) for key in keys] + + def _step(self, key, scaler=None): + if scaler is not None: + scaler.step(self.optimizers[key]) + scaler.update() + else: + self.optimizers[key].step() + + def zero_grad(self, key=None): + if key is not None: + self.optimizers[key].zero_grad() + else: + _ = [self.optimizers[key].zero_grad() for key in self.keys] + + def scheduler(self, *args, key=None): + if key is not None: + self.schedulers[key].step(*args) + else: + _ = [self.schedulers[key].step_batch(*args) for key in self.keys] + + +def define_scheduler(optimizer, params): + scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params["gamma"]) + + return scheduler + + +def build_optimizer(model_dict, scheduler_params_dict, lr, type="AdamW"): + optim = {} + for key, model in model_dict.items(): + model_parameters = model.parameters() + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + if type == "AdamW": + optim[key] = AdamW( + model_parameters, + lr=lr, + betas=(0.9, 0.98), + eps=1e-9, + weight_decay=0.1, + ) + else: + raise ValueError("Unknown optimizer type: %s" % type) + + schedulers = dict( + [ + (key, torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999996)) + for key, opt in optim.items() + ] + ) + + multi_optim = MultiOptimizer(optim, schedulers) + return multi_optim diff --git a/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py b/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fdc9c3e164d161430b9bb03e1b351da16cb82a4f --- /dev/null +++ b/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py @@ -0,0 +1,210 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from concurrent.futures import ALL_COMPLETED +import numpy as np +import torch +import torch.nn as nn + +from torch.nn import functional as F +from einops import rearrange, repeat + +from indextts.utils.maskgct.models.codec.amphion_codec.quantize import ResidualVQ +from indextts.utils.maskgct.models.codec.kmeans.vocos import VocosBackbone + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +def compute_codebook_perplexity(indices, codebook_size): + indices = indices.flatten() + prob = torch.bincount(indices, minlength=codebook_size).float() / indices.size(0) + perp = torch.exp(-torch.sum(prob * torch.log(prob + 1e-10))) + return perp + + +class RepCodec(nn.Module): + def __init__( + self, + codebook_size=8192, + hidden_size=1024, + codebook_dim=8, + vocos_dim=384, + vocos_intermediate_dim=2048, + vocos_num_layers=12, + num_quantizers=1, + downsample_scale=1, + cfg=None, + ): + super().__init__() + codebook_size = ( + cfg.codebook_size + if cfg is not None and hasattr(cfg, "codebook_size") + else codebook_size + ) + codebook_dim = ( + cfg.codebook_dim + if cfg is not None and hasattr(cfg, "codebook_dim") + else codebook_dim + ) + hidden_size = ( + cfg.hidden_size + if cfg is not None and hasattr(cfg, "hidden_size") + else hidden_size + ) + vocos_dim = ( + cfg.vocos_dim + if cfg is not None and hasattr(cfg, "vocos_dim") + else vocos_dim + ) + vocos_intermediate_dim = ( + cfg.vocos_intermediate_dim + if cfg is not None and hasattr(cfg, "vocos_dim") + else vocos_intermediate_dim + ) + vocos_num_layers = ( + cfg.vocos_num_layers + if cfg is not None and hasattr(cfg, "vocos_dim") + else vocos_num_layers + ) + num_quantizers = ( + cfg.num_quantizers + if cfg is not None and hasattr(cfg, "num_quantizers") + else num_quantizers + ) + downsample_scale = ( + cfg.downsample_scale + if cfg is not None and hasattr(cfg, "downsample_scale") + else downsample_scale + ) + + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.hidden_size = hidden_size + self.vocos_dim = vocos_dim + self.vocos_intermediate_dim = vocos_intermediate_dim + self.vocos_num_layers = vocos_num_layers + self.num_quantizers = num_quantizers + self.downsample_scale = downsample_scale + + if self.downsample_scale != None and self.downsample_scale > 1: + self.down = nn.Conv1d( + self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1 + ) + self.up = nn.Conv1d( + self.hidden_size, self.hidden_size, kernel_size=3, stride=1, padding=1 + ) + + self.encoder = nn.Sequential( + VocosBackbone( + input_channels=self.hidden_size, + dim=self.vocos_dim, + intermediate_dim=self.vocos_intermediate_dim, + num_layers=self.vocos_num_layers, + adanorm_num_embeddings=None, + ), + nn.Linear(self.vocos_dim, self.hidden_size), + ) + self.decoder = nn.Sequential( + VocosBackbone( + input_channels=self.hidden_size, + dim=self.vocos_dim, + intermediate_dim=self.vocos_intermediate_dim, + num_layers=self.vocos_num_layers, + adanorm_num_embeddings=None, + ), + nn.Linear(self.vocos_dim, self.hidden_size), + ) + + self.quantizer = ResidualVQ( + input_dim=hidden_size, + num_quantizers=num_quantizers, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_type="fvq", + quantizer_dropout=0.0, + commitment=0.15, + codebook_loss_weight=1.0, + use_l2_normlize=True, + ) + + self.reset_parameters() + + def forward(self, x): + + # downsample + if self.downsample_scale != None and self.downsample_scale > 1: + x = x.transpose(1, 2) + x = self.down(x) + x = F.gelu(x) + x = x.transpose(1, 2) + + # encoder + x = self.encoder(x.transpose(1, 2)).transpose(1, 2) + + # vq + ( + quantized_out, + all_indices, + all_commit_losses, + all_codebook_losses, + _, + ) = self.quantizer(x) + + # decoder + x = self.decoder(quantized_out) + + # up + if self.downsample_scale != None and self.downsample_scale > 1: + x = x.transpose(1, 2) + x = F.interpolate(x, scale_factor=2, mode="nearest") + x_rec = self.up(x).transpose(1, 2) + + codebook_loss = (all_codebook_losses + all_commit_losses).mean() + all_indices = all_indices + + return x_rec, codebook_loss, all_indices + + def quantize(self, x): + + if self.downsample_scale != None and self.downsample_scale > 1: + x = x.transpose(1, 2) + x = self.down(x) + x = F.gelu(x) + x = x.transpose(1, 2) + + x = self.encoder(x.transpose(1, 2)).transpose(1, 2) + + ( + quantized_out, + all_indices, + all_commit_losses, + all_codebook_losses, + _, + ) = self.quantizer(x) + + if all_indices.shape[0] == 1: + return all_indices.squeeze(0), quantized_out.transpose(1, 2) + return all_indices, quantized_out.transpose(1, 2) + + def reset_parameters(self): + self.apply(init_weights) + + +if __name__ == "__main__": + repcodec = RepCodec(vocos_dim=1024, downsample_scale=2) + print(repcodec) + print(sum(p.numel() for p in repcodec.parameters()) / 1e6) + x = torch.randn(5, 10, 1024) + x_rec, codebook_loss, all_indices = repcodec(x) + print(x_rec.shape, codebook_loss, all_indices.shape) + vq_id, emb = repcodec.quantize(x) + print(vq_id.shape, emb.shape) diff --git a/indextts/utils/maskgct/models/codec/kmeans/vocos.py b/indextts/utils/maskgct/models/codec/kmeans/vocos.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9f5a5ce9d8d4283ac313caeb983d3280afacd9 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/kmeans/vocos.py @@ -0,0 +1,850 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import numpy as np +import scipy +import torch +from torch import nn, view_as_real, view_as_complex +from torch import nn +from torch.nn.utils import weight_norm, remove_weight_norm +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + + +class STFT(nn.Module): + def __init__( + self, + n_fft: int, + hop_length: int, + win_length: int, + center=True, + ): + super().__init__() + self.center = center + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, T * hop_length) + + if not self.center: + pad = self.win_length - self.hop_length + x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect") + + stft_spec = torch.stft( + x, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + return_complex=False, + ) # (B, n_fft // 2 + 1, T, 2) + + rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T, 2) + imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T, 2) + + log_mag = torch.log( + torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5 + ) # (B, n_fft // 2 + 1, T) + phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T) + + return log_mag, phase + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__( + self, n_fft: int, hop_length: int, win_length: int, padding: str = "same" + ): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft( + spec, + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + ) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + + +class MDCT(nn.Module): + """ + Modified Discrete Cosine Transform (MDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) + post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) + # view_as_real: NCCL Backend does not support ComplexFloat data type + # https://github.com/pytorch/pytorch/issues/71613 + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + """ + Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. + + Args: + audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size + and T is the length of the audio. + + Returns: + Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames + and N is the number of frequency bins. + """ + if self.padding == "center": + audio = torch.nn.functional.pad( + audio, (self.frame_len // 2, self.frame_len // 2) + ) + elif self.padding == "same": + # hop_length is 1/2 frame_len + audio = torch.nn.functional.pad( + audio, (self.frame_len // 4, self.frame_len // 4) + ) + else: + raise ValueError("Padding must be 'center' or 'same'.") + + x = audio.unfold(-1, self.frame_len, self.frame_len // 2) + N = self.frame_len // 2 + x = x * self.window.expand(x.shape) + X = torch.fft.fft( + x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1 + )[..., :N] + res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) + return torch.real(res) * np.sqrt(2) + + +class IMDCT(nn.Module): + """ + Inverse Modified Discrete Cosine Transform (IMDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. + """ + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft( + Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1 + ) + y = ( + torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) + * np.sqrt(N) + * np.sqrt(2) + ) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio + + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT( + n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip( + mag, max=1e2 + ) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + S = mag * (x + 1j * y) + audio = self.istft(S) + return audio + + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + sample_rate: Optional[int] = None, + clip_audio: bool = False, + ): + super().__init__() + out_dim = mdct_frame_len // 2 + self.out = nn.Linear(dim, out_dim) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + self.clip_audio = clip_audio + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.out.weight.mul_(scale.view(-1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + x = symexp(x) + x = torch.clip( + x, min=-1e2, max=1e2 + ) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + + return audio + + +class IMDCTCosHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + clip_audio: bool = False, + ): + super().__init__() + self.clip_audio = clip_audio + self.out = nn.Linear(dim, mdct_frame_len) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTCosHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + m, p = x.chunk(2, dim=2) + m = torch.exp(m).clip( + max=1e2 + ) # safeguard to prevent excessively large magnitudes + audio = self.imdct(m * torch.cos(p)) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + return audio + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward( + self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) + self.shift = nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + ] + ) + + self.gamma = nn.ParameterList( + [ + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. + """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + bandwidth_id = kwargs.get("bandwidth_id", None) + x = self.embed(x) + if self.adanorm: + assert bandwidth_id is not None + x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, cond_embedding_id=bandwidth_id) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. + """ + + def __init__( + self, + input_channels, + dim, + num_blocks, + layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm( + nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) + ) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ + ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) + for _ in range(num_blocks) + ] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x + + +class Vocos(nn.Module): + def __init__( + self, + input_channels: int = 256, + dim: int = 384, + intermediate_dim: int = 1152, + num_layers: int = 8, + adanorm_num_embeddings: int = 4, + n_fft: int = 800, + hop_size: int = 200, + padding: str = "same", + ): + super().__init__() + + self.backbone = VocosBackbone( + input_channels=input_channels, + dim=dim, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + self.head = ISTFTHead(dim, n_fft, hop_size, padding) + + def forward(self, x): + x = self.backbone(x) + x = self.head(x) + + return x[:, None, :] diff --git a/indextts/utils/maskgct/models/codec/melvqgan/melspec.py b/indextts/utils/maskgct/models/codec/melvqgan/melspec.py new file mode 100755 index 0000000000000000000000000000000000000000..0cfa6b8768623cd283f64482a08bae4700e79ad5 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/melvqgan/melspec.py @@ -0,0 +1,108 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import pyworld as pw +import numpy as np +import soundfile as sf +import os +from torchaudio.functional import pitch_shift +import librosa +from librosa.filters import mel as librosa_mel_fn +import torch.nn as nn +import torch.nn.functional as F +import tqdm + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +class MelSpectrogram(nn.Module): + def __init__( + self, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + center=False, + ): + super(MelSpectrogram, self).__init__() + self.n_fft = n_fft + self.hop_size = hop_size + self.win_size = win_size + self.sampling_rate = sampling_rate + self.num_mels = num_mels + self.fmin = fmin + self.fmax = fmax + self.center = center + + mel_basis = {} + hann_window = {} + + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis = torch.from_numpy(mel).float() + hann_window = torch.hann_window(win_size) + + self.register_buffer("mel_basis", mel_basis) + self.register_buffer("hann_window", hann_window) + + def forward(self, y): + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + int((self.n_fft - self.hop_size) / 2), + int((self.n_fft - self.hop_size) / 2), + ), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(self.mel_basis, spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/README.md b/indextts/utils/maskgct/models/codec/ns3_codec/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1283d677f41d16f72577586ceb57bcf49241280c --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/README.md @@ -0,0 +1,216 @@ +## FACodec: Speech Codec with Attribute Factorization used for NaturalSpeech 3 + +[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/pdf/2403.03100.pdf) +[![demo](https://img.shields.io/badge/FACodec-Demo-red)](https://speechresearch.github.io/naturalspeech3/) +[![model](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Models-pink)](https://huggingface.co/amphion/naturalspeech3_facodec) +[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/naturalspeech3_facodec) + +## Overview + +FACodec is a core component of the advanced text-to-speech (TTS) model NaturalSpeech 3. FACodec converts complex speech waveform into disentangled subspaces representing speech attributes of content, prosody, timbre, and acoustic details and reconstruct high-quality speech waveform from these attributes. FACodec decomposes complex speech into subspaces representing different attributes, thus simplifying the modeling of speech representation. + +Research can use FACodec to develop different modes of TTS models, such as non-autoregressive based discrete diffusion (NaturalSpeech 3) or autoregressive models (like VALL-E). + +
+
+ +
+
+ +
+
+ +
+
+ +## Useage + +Download the pre-trained FACodec model from HuggingFace: [Pretrained FACodec checkpoint](https://huggingface.co/amphion/naturalspeech3_facodec) + +Install Amphion +```bash +git clone https://github.com/open-mmlab/Amphion.git +``` + +Few lines of code to use the pre-trained FACodec model +```python +from Amphion.models.codec.ns3_codec import FACodecEncoder, FACodecDecoder +from huggingface_hub import hf_hub_download + +fa_encoder = FACodecEncoder( + ngf=32, + up_ratios=[2, 4, 5, 5], + out_channels=256, +) + +fa_decoder = FACodecDecoder( + in_channels=256, + upsample_initial_channel=1024, + ngf=32, + up_ratios=[5, 5, 4, 2], + vq_num_q_c=2, + vq_num_q_p=1, + vq_num_q_r=3, + vq_dim=256, + codebook_dim=8, + codebook_size_prosody=10, + codebook_size_content=10, + codebook_size_residual=10, + use_gr_x_timbre=True, + use_gr_residual_f0=True, + use_gr_residual_phone=True, +) + +encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin") +decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin") + +fa_encoder.load_state_dict(torch.load(encoder_ckpt)) +fa_decoder.load_state_dict(torch.load(decoder_ckpt)) + +fa_encoder.eval() +fa_decoder.eval() + +``` + +Inference +```python +test_wav_path = "test.wav" +test_wav = librosa.load(test_wav_path, sr=16000)[0] +test_wav = torch.from_numpy(test_wav).float() +test_wav = test_wav.unsqueeze(0).unsqueeze(0) + +with torch.no_grad(): + + # encode + enc_out = fa_encoder(test_wav) + print(enc_out.shape) + + # quantize + vq_post_emb, vq_id, _, quantized, spk_embs = fa_decoder(enc_out, eval_vq=False, vq=True) + + # latent after quantization + print(vq_post_emb.shape) + + # codes + print("vq id shape:", vq_id.shape) + + # get prosody code + prosody_code = vq_id[:1] + print("prosody code shape:", prosody_code.shape) + + # get content code + cotent_code = vq_id[1:3] + print("content code shape:", cotent_code.shape) + + # get residual code (acoustic detail codes) + residual_code = vq_id[3:] + print("residual code shape:", residual_code.shape) + + # speaker embedding + print("speaker embedding shape:", spk_embs.shape) + + # decode (recommand) + recon_wav = fa_decoder.inference(vq_post_emb, spk_embs) + print(recon_wav.shape) + sf.write("recon.wav", recon_wav[0][0].cpu().numpy(), 16000) +``` + +FACodec can achieve zero-shot voice conversion with FACodecEncoderV2/FACodecDecoderV2 or FACodecRedecoder +```python +from Amphion.models.codec.ns3_codec import FACodecEncoderV2, FACodecDecoderV2 + +# Same parameters as FACodecEncoder/FACodecDecoder +fa_encoder_v2 = FACodecEncoderV2(...) +fa_decoder_v2 = FACodecDecoderV2(...) + +encoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder_v2.bin") +decoder_v2_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder_v2.bin") + +fa_encoder_v2.load_state_dict(torch.load(encoder_v2_ckpt)) +fa_decoder_v2.load_state_dict(torch.load(decoder_v2_ckpt)) + +with torch.no_grad(): + enc_out_a = fa_encoder_v2(wav_a) + prosody_a = fa_encoder_v2.get_prosody_feature(wav_a) + enc_out_b = fa_encoder_v2(wav_b) + prosody_b = fa_encoder_v2.get_prosody_feature(wav_b) + + vq_post_emb_a, vq_id_a, _, quantized, spk_embs_a = fa_decoder_v2( + enc_out_a, prosody_a, eval_vq=False, vq=True + ) + vq_post_emb_b, vq_id_b, _, quantized, spk_embs_b = fa_decoder_v2( + enc_out_b, prosody_b, eval_vq=False, vq=True + ) + + vq_post_emb_a_to_b = fa_decoder_v2.vq2emb(vq_id_a, use_residual=False) + recon_wav_a_to_b = fa_decoder_v2.inference(vq_post_emb_a_to_b, spk_embs_b) +``` + +or + +```python +from Amphion.models.codec.ns3_codec import FACodecRedecoder + +fa_redecoder = FACodecRedecoder() + +redecoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_redecoder.bin") + +fa_redecoder.load_state_dict(torch.load(redecoder_ckpt)) + +with torch.no_grad(): + enc_out_a = fa_encoder(wav_a) + enc_out_b = fa_encoder(wav_b) + + vq_post_emb_a, vq_id_a, _, quantized_a, spk_embs_a = fa_decoder(enc_out_a, eval_vq=False, vq=True) + vq_post_emb_b, vq_id_b, _, quantized_b, spk_embs_b = fa_decoder(enc_out_b, eval_vq=False, vq=True) + + # convert speaker + vq_post_emb_a_to_b = fa_redecoder.vq2emb(vq_id_a, spk_embs_b, use_residual=False) + recon_wav_a_to_b = fa_redecoder.inference(vq_post_emb_a_to_b, spk_embs_b) + + sf.write("recon_a_to_b.wav", recon_wav_a_to_b[0][0].cpu().numpy(), 16000) +``` + +## Q&A + +Q1: What audio sample rate does FACodec support? What is the hop size? How many codes will be generated for each frame? + +A1: FACodec supports 16KHz speech audio. The hop size is 200 samples, and (16000/200) * 6 (total number of codebooks) codes will be generated for each frame. + +Q2: Is it possible to train an autoregressive TTS model like VALL-E using FACodec? + +A2: Yes. In fact, the authors of NaturalSpeech 3 have already employ explore the autoregressive generative model for discrete token generation with FACodec. They use an autoregressive language model to generate prosody codes, followed by a non-autoregressive model to generate the remaining content and acoustic details codes. + +Q3: Is it possible to train a latent diffusion TTS model like NaturalSpeech2 using FACodec? + +A3: Yes. You can use the latent getted after quanzaition as the modelling target for the latent diffusion model. + +Q4: Can FACodec compress and reconstruct audio from other domains? Such as sound effects, music, etc. + +A4: Since FACodec is designed for speech, it may not be suitable for other audio domains. However, it is possible to use the FACodec model to compress and reconstruct audio from other domains, but the quality may not be as good as the original audio. + +Q5: Can FACodec be used for content feature for some other tasks like voice conversion? + +A5: I think the answer is yes. Researchers can use the content code of FACodec as the content feature for voice conversion. We hope to see more research in this direction. + +## Citations + +If you use our FACodec model, please cite the following paper: + +```bibtex +@article{ju2024naturalspeech, + title={NaturalSpeech 3: Zero-Shot Speech Synthesis with Factorized Codec and Diffusion Models}, + author={Ju, Zeqian and Wang, Yuancheng and Shen, Kai and Tan, Xu and Xin, Detai and Yang, Dongchao and Liu, Yanqing and Leng, Yichong and Song, Kaitao and Tang, Siliang and others}, + journal={arXiv preprint arXiv:2403.03100}, + year={2024} +} + +@article{zhang2023amphion, + title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit}, + author={Xueyao Zhang and Liumeng Xue and Yicheng Gu and Yuancheng Wang and Haorui He and Chaoren Wang and Xi Chen and Zihao Fang and Haopeng Chen and Junan Zhang and Tze Ying Tang and Lexiao Zou and Mingxuan Wang and Jun Han and Kai Chen and Haizhou Li and Zhizheng Wu}, + journal={arXiv}, + year={2024}, + volume={abs/2312.09911} +} +``` + diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py b/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f0e4c194e9a02d93f7f1c8f4bed05f460b36e20 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .facodec import * diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py b/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bccdb97a3706bcb7149f48e04178cf00a5e877 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py @@ -0,0 +1,5 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +from .filter import * +from .resample import * +from .act import * diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py b/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..779d58d5f1e889f8b639dd019a0ce951e69e4cfb --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py @@ -0,0 +1,29 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py b/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..ece8e02fce0e65e13522e990a80d1bfeeffd46ba --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py @@ -0,0 +1,96 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +def kaiser_sinc_filter1d( + cutoff, half_width, kernel_size +): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + # input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py b/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..ee993b10339141b469b67c3e11f5d73c5f4e0bca --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py @@ -0,0 +1,57 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py b/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py new file mode 100644 index 0000000000000000000000000000000000000000..87f661bdfa250e5be7514946934f6873a981b9c3 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py @@ -0,0 +1,1222 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn, sin, pow +from torch.nn import Parameter +import torch.nn.functional as F +from torch.nn.utils import weight_norm +from .alias_free_torch import * +from .quantize import * +from einops import rearrange +from einops.layers.torch import Rearrange +from .transformer import TransformerEncoder +from .gradient_reversal import GradientReversal +from .melspec import MelSpectrogram + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class CNNLSTM(nn.Module): + def __init__(self, indim, outdim, head, global_pred=False): + super().__init__() + self.global_pred = global_pred + self.model = nn.Sequential( + ResidualUnit(indim, dilation=1), + ResidualUnit(indim, dilation=2), + ResidualUnit(indim, dilation=3), + Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)), + Rearrange("b c t -> b t c"), + ) + self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)]) + + def forward(self, x): + # x: [B, C, T] + x = self.model(x) + if self.global_pred: + x = torch.mean(x, dim=1, keepdim=False) + outs = [head(x) for head in self.heads] + return outs + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False + ): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta := x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + return x + self.block(x) + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + ), + ) + + def forward(self, x): + return self.block(x) + + +class FACodecEncoder(nn.Module): + def __init__( + self, + ngf=32, + up_ratios=(2, 4, 5, 5), + out_channels=1024, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.up_ratios = up_ratios + + # Create first convolution + d_model = ngf + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in up_ratios: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)), + WNConv1d(d_model, out_channels, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + self.reset_parameters() + + def forward(self, x): + out = self.block(x) + return out + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Activation1d(activation=SnakeBeta(input_dim, alpha_logscale=True)), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class FACodecDecoder(nn.Module): + def __init__( + self, + in_channels=256, + upsample_initial_channel=1536, + ngf=32, + up_ratios=(5, 5, 4, 2), + vq_num_q_c=2, + vq_num_q_p=1, + vq_num_q_r=3, + vq_dim=1024, + vq_commit_weight=0.005, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_dim=8, + codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size + codebook_size_content=10, + codebook_size_residual=10, + quantizer_dropout=0.0, + dropout_type="linear", + use_gr_content_f0=False, + use_gr_prosody_phone=False, + use_gr_residual_f0=False, + use_gr_residual_phone=False, + use_gr_x_timbre=False, + use_random_mask_residual=True, + prob_random_mask_residual=0.75, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf = ngf + self.up_ratios = up_ratios + + self.use_random_mask_residual = use_random_mask_residual + self.prob_random_mask_residual = prob_random_mask_residual + + self.vq_num_q_p = vq_num_q_p + self.vq_num_q_c = vq_num_q_c + self.vq_num_q_r = vq_num_q_r + + self.codebook_size_prosody = codebook_size_prosody + self.codebook_size_content = codebook_size_content + self.codebook_size_residual = codebook_size_residual + + quantizer_class = ResidualVQ + + self.quantizer = nn.ModuleList() + + # prosody + quantizer = quantizer_class( + num_quantizers=vq_num_q_p, + dim=vq_dim, + codebook_size=codebook_size_prosody, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + quantizer_dropout=quantizer_dropout, + dropout_type=dropout_type, + ) + self.quantizer.append(quantizer) + + # phone + quantizer = quantizer_class( + num_quantizers=vq_num_q_c, + dim=vq_dim, + codebook_size=codebook_size_content, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + quantizer_dropout=quantizer_dropout, + dropout_type=dropout_type, + ) + self.quantizer.append(quantizer) + + # residual + if self.vq_num_q_r > 0: + quantizer = quantizer_class( + num_quantizers=vq_num_q_r, + dim=vq_dim, + codebook_size=codebook_size_residual, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + quantizer_dropout=quantizer_dropout, + dropout_type=dropout_type, + ) + self.quantizer.append(quantizer) + + # Add first conv layer + channels = upsample_initial_channel + layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(up_ratios): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)), + WNConv1d(output_dim, 1, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + self.timbre_encoder = TransformerEncoder( + enc_emb_tokens=None, + encoder_layer=4, + encoder_hidden=256, + encoder_head=4, + conv_filter_size=1024, + conv_kernel_size=5, + encoder_dropout=0.1, + use_cln=False, + ) + + self.timbre_linear = nn.Linear(in_channels, in_channels * 2) + self.timbre_linear.bias.data[:in_channels] = 1 + self.timbre_linear.bias.data[in_channels:] = 0 + self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False) + + self.f0_predictor = CNNLSTM(in_channels, 1, 2) + self.phone_predictor = CNNLSTM(in_channels, 5003, 1) + + self.use_gr_content_f0 = use_gr_content_f0 + self.use_gr_prosody_phone = use_gr_prosody_phone + self.use_gr_residual_f0 = use_gr_residual_f0 + self.use_gr_residual_phone = use_gr_residual_phone + self.use_gr_x_timbre = use_gr_x_timbre + + if self.vq_num_q_r > 0 and self.use_gr_residual_f0: + self.res_f0_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2) + ) + + if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0: + self.res_phone_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1) + ) + + if self.use_gr_content_f0: + self.content_f0_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2) + ) + + if self.use_gr_prosody_phone: + self.prosody_phone_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1) + ) + + if self.use_gr_x_timbre: + self.x_timbre_predictor = nn.Sequential( + GradientReversal(alpha=1), + CNNLSTM(in_channels, 245200, 1, global_pred=True), + ) + + self.reset_parameters() + + def quantize(self, x, n_quantizers=None): + outs, qs, commit_loss, quantized_buf = 0, [], [], [] + + # prosody + f0_input = x # (B, d, T) + f0_quantizer = self.quantizer[0] + out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers) + outs += out + qs.append(q) + quantized_buf.append(quantized.sum(0)) + commit_loss.append(commit) + + # phone + phone_input = x + phone_quantizer = self.quantizer[1] + out, q, commit, quantized = phone_quantizer( + phone_input, n_quantizers=n_quantizers + ) + outs += out + qs.append(q) + quantized_buf.append(quantized.sum(0)) + commit_loss.append(commit) + + # residual + if self.vq_num_q_r > 0: + residual_quantizer = self.quantizer[2] + residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach() + out, q, commit, quantized = residual_quantizer( + residual_input, n_quantizers=n_quantizers + ) + outs += out + qs.append(q) + quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T] + commit_loss.append(commit) + + qs = torch.cat(qs, dim=0) + commit_loss = torch.cat(commit_loss, dim=0) + return outs, qs, commit_loss, quantized_buf + + def forward( + self, + x, + vq=True, + get_vq=False, + eval_vq=True, + speaker_embedding=None, + n_quantizers=None, + quantized=None, + ): + if get_vq: + return self.quantizer.get_emb() + if vq is True: + if eval_vq: + self.quantizer.eval() + x_timbre = x + outs, qs, commit_loss, quantized_buf = self.quantize( + x, n_quantizers=n_quantizers + ) + + x_timbre = x_timbre.transpose(1, 2) + x_timbre = self.timbre_encoder(x_timbre, None, None) + x_timbre = x_timbre.transpose(1, 2) + spk_embs = torch.mean(x_timbre, dim=2) + return outs, qs, commit_loss, quantized_buf, spk_embs + + out = {} + + layer_0 = quantized[0] + f0, uv = self.f0_predictor(layer_0) + f0 = rearrange(f0, "... 1 -> ...") + uv = rearrange(uv, "... 1 -> ...") + + layer_1 = quantized[1] + (phone,) = self.phone_predictor(layer_1) + + out = {"f0": f0, "uv": uv, "phone": phone} + + if self.use_gr_prosody_phone: + (prosody_phone,) = self.prosody_phone_predictor(layer_0) + out["prosody_phone"] = prosody_phone + + if self.use_gr_content_f0: + content_f0, content_uv = self.content_f0_predictor(layer_1) + content_f0 = rearrange(content_f0, "... 1 -> ...") + content_uv = rearrange(content_uv, "... 1 -> ...") + out["content_f0"] = content_f0 + out["content_uv"] = content_uv + + if self.vq_num_q_r > 0: + layer_2 = quantized[2] + + if self.use_gr_residual_f0: + res_f0, res_uv = self.res_f0_predictor(layer_2) + res_f0 = rearrange(res_f0, "... 1 -> ...") + res_uv = rearrange(res_uv, "... 1 -> ...") + out["res_f0"] = res_f0 + out["res_uv"] = res_uv + + if self.use_gr_residual_phone: + (res_phone,) = self.res_phone_predictor(layer_2) + out["res_phone"] = res_phone + + style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + if self.vq_num_q_r > 0: + if self.use_random_mask_residual: + bsz = quantized[2].shape[0] + res_mask = np.random.choice( + [0, 1], + size=bsz, + p=[ + self.prob_random_mask_residual, + 1 - self.prob_random_mask_residual, + ], + ) + res_mask = ( + torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) + ) # (B, 1, 1) + res_mask = res_mask.to( + device=quantized[2].device, dtype=quantized[2].dtype + ) + x = ( + quantized[0].detach() + + quantized[1].detach() + + quantized[2] * res_mask + ) + # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask + else: + x = quantized[0].detach() + quantized[1].detach() + quantized[2] + # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] + else: + x = quantized[0].detach() + quantized[1].detach() + # x = quantized_perturbe[0].detach() + quantized[1].detach() + + if self.use_gr_x_timbre: + (x_timbre,) = self.x_timbre_predictor(x) + out["x_timbre"] = x_timbre + + x = x.transpose(1, 2) + x = self.timbre_norm(x) + x = x.transpose(1, 2) + x = x * gamma + beta + + x = self.model(x) + out["audio"] = x + + return out + + def vq2emb(self, vq, use_residual_code=True): + # vq: [num_quantizer, B, T] + self.quantizer = self.quantizer.eval() + out = 0 + out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p]) + out += self.quantizer[1].vq2emb( + vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c] + ) + if self.vq_num_q_r > 0 and use_residual_code: + out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :]) + return out + + def inference(self, x, speaker_embedding): + style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + x = x.transpose(1, 2) + x = self.timbre_norm(x) + x = x.transpose(1, 2) + x = x * gamma + beta + x = self.model(x) + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + +class FACodecRedecoder(nn.Module): + def __init__( + self, + in_channels=256, + upsample_initial_channel=1280, + up_ratios=(5, 5, 4, 2), + vq_num_q_c=2, + vq_num_q_p=1, + vq_num_q_r=3, + vq_dim=256, + codebook_size_prosody=10, + codebook_size_content=10, + codebook_size_residual=10, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.up_ratios = up_ratios + + self.vq_num_q_p = vq_num_q_p + self.vq_num_q_c = vq_num_q_c + self.vq_num_q_r = vq_num_q_r + + self.vq_dim = vq_dim + + self.codebook_size_prosody = codebook_size_prosody + self.codebook_size_content = codebook_size_content + self.codebook_size_residual = codebook_size_residual + + self.prosody_embs = nn.ModuleList() + for i in range(self.vq_num_q_p): + emb_tokens = nn.Embedding( + num_embeddings=2**self.codebook_size_prosody, + embedding_dim=self.vq_dim, + ) + emb_tokens.weight.data.normal_(mean=0.0, std=1e-5) + self.prosody_embs.append(emb_tokens) + self.content_embs = nn.ModuleList() + for i in range(self.vq_num_q_c): + emb_tokens = nn.Embedding( + num_embeddings=2**self.codebook_size_content, + embedding_dim=self.vq_dim, + ) + emb_tokens.weight.data.normal_(mean=0.0, std=1e-5) + self.content_embs.append(emb_tokens) + self.residual_embs = nn.ModuleList() + for i in range(self.vq_num_q_r): + emb_tokens = nn.Embedding( + num_embeddings=2**self.codebook_size_residual, + embedding_dim=self.vq_dim, + ) + emb_tokens.weight.data.normal_(mean=0.0, std=1e-5) + self.residual_embs.append(emb_tokens) + + # Add first conv layer + channels = upsample_initial_channel + layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(up_ratios): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)), + WNConv1d(output_dim, 1, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + self.timbre_linear = nn.Linear(in_channels, in_channels * 2) + self.timbre_linear.bias.data[:in_channels] = 1 + self.timbre_linear.bias.data[in_channels:] = 0 + self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False) + + self.timbre_cond_prosody_enc = TransformerEncoder( + enc_emb_tokens=None, + encoder_layer=4, + encoder_hidden=256, + encoder_head=4, + conv_filter_size=1024, + conv_kernel_size=5, + encoder_dropout=0.1, + use_cln=True, + cfg=None, + ) + + def forward( + self, + vq, + speaker_embedding, + use_residual_code=False, + ): + + x = 0 + + x_p = 0 + for i in range(self.vq_num_q_p): + x_p = x_p + self.prosody_embs[i](vq[i]) # (B, T, d) + spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_p.shape[1], -1) + x_p = self.timbre_cond_prosody_enc( + x_p, key_padding_mask=None, condition=spk_cond + ) + x = x + x_p + + x_c = 0 + for i in range(self.vq_num_q_c): + x_c = x_c + self.content_embs[i](vq[self.vq_num_q_p + i]) + + x = x + x_c + + if use_residual_code: + + x_r = 0 + for i in range(self.vq_num_q_r): + x_r = x_r + self.residual_embs[i]( + vq[self.vq_num_q_p + self.vq_num_q_c + i] + ) + x = x + x_r + + style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + x = x.transpose(1, 2) + x = self.timbre_norm(x) + x = x.transpose(1, 2) + x = x * gamma + beta + x = self.model(x) + + return x + + def vq2emb(self, vq, speaker_embedding, use_residual=True): + + out = 0 + + x_t = 0 + for i in range(self.vq_num_q_p): + x_t += self.prosody_embs[i](vq[i]) # (B, T, d) + spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_t.shape[1], -1) + x_t = self.timbre_cond_prosody_enc( + x_t, key_padding_mask=None, condition=spk_cond + ) + + # prosody + out += x_t + + # content + for i in range(self.vq_num_q_c): + out += self.content_embs[i](vq[self.vq_num_q_p + i]) + + # residual + if use_residual: + for i in range(self.vq_num_q_r): + out += self.residual_embs[i](vq[self.vq_num_q_p + self.vq_num_q_c + i]) + + out = out.transpose(1, 2) # (B, T, d) -> (B, d, T) + return out + + def inference(self, x, speaker_embedding): + style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + x = x.transpose(1, 2) + x = self.timbre_norm(x) + x = x.transpose(1, 2) + x = x * gamma + beta + x = self.model(x) + return x + + +class FACodecEncoderV2(nn.Module): + def __init__( + self, + ngf=32, + up_ratios=(2, 4, 5, 5), + out_channels=1024, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.up_ratios = up_ratios + + # Create first convolution + d_model = ngf + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in up_ratios: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)), + WNConv1d(d_model, out_channels, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + self.mel_transform = MelSpectrogram( + n_fft=1024, + num_mels=80, + sampling_rate=16000, + hop_size=200, + win_size=800, + fmin=0, + fmax=8000, + ) + + self.reset_parameters() + + def forward(self, x): + out = self.block(x) + return out + + def inference(self, x): + return self.block(x) + + def get_prosody_feature(self, x): + return self.mel_transform(x.squeeze(1))[:, :20, :] + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + +class FACodecDecoderV2(nn.Module): + def __init__( + self, + in_channels=256, + upsample_initial_channel=1536, + ngf=32, + up_ratios=(5, 5, 4, 2), + vq_num_q_c=2, + vq_num_q_p=1, + vq_num_q_r=3, + vq_dim=1024, + vq_commit_weight=0.005, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_dim=8, + codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size + codebook_size_content=10, + codebook_size_residual=10, + quantizer_dropout=0.0, + dropout_type="linear", + use_gr_content_f0=False, + use_gr_prosody_phone=False, + use_gr_residual_f0=False, + use_gr_residual_phone=False, + use_gr_x_timbre=False, + use_random_mask_residual=True, + prob_random_mask_residual=0.75, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf = ngf + self.up_ratios = up_ratios + + self.use_random_mask_residual = use_random_mask_residual + self.prob_random_mask_residual = prob_random_mask_residual + + self.vq_num_q_p = vq_num_q_p + self.vq_num_q_c = vq_num_q_c + self.vq_num_q_r = vq_num_q_r + + self.codebook_size_prosody = codebook_size_prosody + self.codebook_size_content = codebook_size_content + self.codebook_size_residual = codebook_size_residual + + quantizer_class = ResidualVQ + + self.quantizer = nn.ModuleList() + + # prosody + quantizer = quantizer_class( + num_quantizers=vq_num_q_p, + dim=vq_dim, + codebook_size=codebook_size_prosody, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + quantizer_dropout=quantizer_dropout, + dropout_type=dropout_type, + ) + self.quantizer.append(quantizer) + + # phone + quantizer = quantizer_class( + num_quantizers=vq_num_q_c, + dim=vq_dim, + codebook_size=codebook_size_content, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + quantizer_dropout=quantizer_dropout, + dropout_type=dropout_type, + ) + self.quantizer.append(quantizer) + + # residual + if self.vq_num_q_r > 0: + quantizer = quantizer_class( + num_quantizers=vq_num_q_r, + dim=vq_dim, + codebook_size=codebook_size_residual, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + quantizer_dropout=quantizer_dropout, + dropout_type=dropout_type, + ) + self.quantizer.append(quantizer) + + # Add first conv layer + channels = upsample_initial_channel + layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(up_ratios): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)), + WNConv1d(output_dim, 1, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + self.timbre_encoder = TransformerEncoder( + enc_emb_tokens=None, + encoder_layer=4, + encoder_hidden=256, + encoder_head=4, + conv_filter_size=1024, + conv_kernel_size=5, + encoder_dropout=0.1, + use_cln=False, + ) + + self.timbre_linear = nn.Linear(in_channels, in_channels * 2) + self.timbre_linear.bias.data[:in_channels] = 1 + self.timbre_linear.bias.data[in_channels:] = 0 + self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False) + + self.f0_predictor = CNNLSTM(in_channels, 1, 2) + self.phone_predictor = CNNLSTM(in_channels, 5003, 1) + + self.use_gr_content_f0 = use_gr_content_f0 + self.use_gr_prosody_phone = use_gr_prosody_phone + self.use_gr_residual_f0 = use_gr_residual_f0 + self.use_gr_residual_phone = use_gr_residual_phone + self.use_gr_x_timbre = use_gr_x_timbre + + if self.vq_num_q_r > 0 and self.use_gr_residual_f0: + self.res_f0_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2) + ) + + if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0: + self.res_phone_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1) + ) + + if self.use_gr_content_f0: + self.content_f0_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2) + ) + + if self.use_gr_prosody_phone: + self.prosody_phone_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1) + ) + + if self.use_gr_x_timbre: + self.x_timbre_predictor = nn.Sequential( + GradientReversal(alpha=1), + CNNLSTM(in_channels, 245200, 1, global_pred=True), + ) + + self.melspec_linear = nn.Linear(20, 256) + self.melspec_encoder = TransformerEncoder( + enc_emb_tokens=None, + encoder_layer=4, + encoder_hidden=256, + encoder_head=4, + conv_filter_size=1024, + conv_kernel_size=5, + encoder_dropout=0.1, + use_cln=False, + cfg=None, + ) + + self.reset_parameters() + + def quantize(self, x, prosody_feature, n_quantizers=None): + outs, qs, commit_loss, quantized_buf = 0, [], [], [] + + # prosody + f0_input = prosody_feature.transpose(1, 2) # (B, T, 20) + f0_input = self.melspec_linear(f0_input) + f0_input = self.melspec_encoder(f0_input, None, None) + f0_input = f0_input.transpose(1, 2) + f0_quantizer = self.quantizer[0] + out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers) + outs += out + qs.append(q) + quantized_buf.append(quantized.sum(0)) + commit_loss.append(commit) + + # phone + phone_input = x + phone_quantizer = self.quantizer[1] + out, q, commit, quantized = phone_quantizer( + phone_input, n_quantizers=n_quantizers + ) + outs += out + qs.append(q) + quantized_buf.append(quantized.sum(0)) + commit_loss.append(commit) + + # residual + if self.vq_num_q_r > 0: + residual_quantizer = self.quantizer[2] + residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach() + out, q, commit, quantized = residual_quantizer( + residual_input, n_quantizers=n_quantizers + ) + outs += out + qs.append(q) + quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T] + commit_loss.append(commit) + + qs = torch.cat(qs, dim=0) + commit_loss = torch.cat(commit_loss, dim=0) + return outs, qs, commit_loss, quantized_buf + + def forward( + self, + x, + prosody_feature, + vq=True, + get_vq=False, + eval_vq=True, + speaker_embedding=None, + n_quantizers=None, + quantized=None, + ): + if get_vq: + return self.quantizer.get_emb() + if vq is True: + if eval_vq: + self.quantizer.eval() + x_timbre = x + outs, qs, commit_loss, quantized_buf = self.quantize( + x, prosody_feature, n_quantizers=n_quantizers + ) + + x_timbre = x_timbre.transpose(1, 2) + x_timbre = self.timbre_encoder(x_timbre, None, None) + x_timbre = x_timbre.transpose(1, 2) + spk_embs = torch.mean(x_timbre, dim=2) + return outs, qs, commit_loss, quantized_buf, spk_embs + + out = {} + + layer_0 = quantized[0] + f0, uv = self.f0_predictor(layer_0) + f0 = rearrange(f0, "... 1 -> ...") + uv = rearrange(uv, "... 1 -> ...") + + layer_1 = quantized[1] + (phone,) = self.phone_predictor(layer_1) + + out = {"f0": f0, "uv": uv, "phone": phone} + + if self.use_gr_prosody_phone: + (prosody_phone,) = self.prosody_phone_predictor(layer_0) + out["prosody_phone"] = prosody_phone + + if self.use_gr_content_f0: + content_f0, content_uv = self.content_f0_predictor(layer_1) + content_f0 = rearrange(content_f0, "... 1 -> ...") + content_uv = rearrange(content_uv, "... 1 -> ...") + out["content_f0"] = content_f0 + out["content_uv"] = content_uv + + if self.vq_num_q_r > 0: + layer_2 = quantized[2] + + if self.use_gr_residual_f0: + res_f0, res_uv = self.res_f0_predictor(layer_2) + res_f0 = rearrange(res_f0, "... 1 -> ...") + res_uv = rearrange(res_uv, "... 1 -> ...") + out["res_f0"] = res_f0 + out["res_uv"] = res_uv + + if self.use_gr_residual_phone: + (res_phone,) = self.res_phone_predictor(layer_2) + out["res_phone"] = res_phone + + style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + if self.vq_num_q_r > 0: + if self.use_random_mask_residual: + bsz = quantized[2].shape[0] + res_mask = np.random.choice( + [0, 1], + size=bsz, + p=[ + self.prob_random_mask_residual, + 1 - self.prob_random_mask_residual, + ], + ) + res_mask = ( + torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) + ) # (B, 1, 1) + res_mask = res_mask.to( + device=quantized[2].device, dtype=quantized[2].dtype + ) + x = ( + quantized[0].detach() + + quantized[1].detach() + + quantized[2] * res_mask + ) + # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask + else: + x = quantized[0].detach() + quantized[1].detach() + quantized[2] + # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] + else: + x = quantized[0].detach() + quantized[1].detach() + # x = quantized_perturbe[0].detach() + quantized[1].detach() + + if self.use_gr_x_timbre: + (x_timbre,) = self.x_timbre_predictor(x) + out["x_timbre"] = x_timbre + + x = x.transpose(1, 2) + x = self.timbre_norm(x) + x = x.transpose(1, 2) + x = x * gamma + beta + + x = self.model(x) + out["audio"] = x + + return out + + def vq2emb(self, vq, use_residual=True): + # vq: [num_quantizer, B, T] + self.quantizer = self.quantizer.eval() + out = 0 + out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p]) + out += self.quantizer[1].vq2emb( + vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c] + ) + if self.vq_num_q_r > 0 and use_residual: + out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :]) + return out + + def inference(self, x, speaker_embedding): + style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + x = x.transpose(1, 2) + x = self.timbre_norm(x) + x = x.transpose(1, 2) + x = x * gamma + beta + x = self.model(x) + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py b/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py new file mode 100644 index 0000000000000000000000000000000000000000..d09396ea20c653b2a443e144ab429f534ce033fd --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch.autograd import Function +import torch +from torch import nn + + +class GradientReversal(Function): + @staticmethod + def forward(ctx, x, alpha): + ctx.save_for_backward(x, alpha) + return x + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + _, alpha = ctx.saved_tensors + if ctx.needs_input_grad[0]: + grad_input = -alpha * grad_output + return grad_input, None + + +revgrad = GradientReversal.apply + + +class GradientReversal(nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = torch.tensor(alpha, requires_grad=False) + + def forward(self, x): + return revgrad(x, self.alpha) diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py b/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf1cd24ebe533477be0833723b84b0b1d75c2d8 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py @@ -0,0 +1,102 @@ +import torch +import pyworld as pw +import numpy as np +import soundfile as sf +import os +from torchaudio.functional import pitch_shift +import librosa +from librosa.filters import mel as librosa_mel_fn +import torch.nn as nn +import torch.nn.functional as F + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +class MelSpectrogram(nn.Module): + def __init__( + self, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + center=False, + ): + super(MelSpectrogram, self).__init__() + self.n_fft = n_fft + self.hop_size = hop_size + self.win_size = win_size + self.sampling_rate = sampling_rate + self.num_mels = num_mels + self.fmin = fmin + self.fmax = fmax + self.center = center + + mel_basis = {} + hann_window = {} + + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis = torch.from_numpy(mel).float() + hann_window = torch.hann_window(win_size) + + self.register_buffer("mel_basis", mel_basis) + self.register_buffer("hann_window", hann_window) + + def forward(self, y): + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + int((self.n_fft - self.hop_size) / 2), + int((self.n_fft - self.hop_size) / 2), + ), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(self.mel_basis, spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py b/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb7b4063ca2364ccc2658a8e19061fb65ddd7a7 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .fvq import * +from .rvq import * diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py b/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py new file mode 100644 index 0000000000000000000000000000000000000000..4ade35d889497a8d42034b6cf00aea48c92c5422 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +class FactorizedVectorQuantize(nn.Module): + def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + + if dim != self.codebook_dim: + self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim)) + self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim)) + else: + self.in_proj = nn.Identity() + self.out_proj = nn.Identity() + self._codebook = nn.Embedding(codebook_size, self.codebook_dim) + + @property + def codebook(self): + return self._codebook + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + # transpose since we use linear + + z = rearrange(z, "b d t -> b t d") + + # Factorized codes project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x T x D) + z_e = rearrange(z_e, "b t d -> b d t") + z_q, indices = self.decode_latents(z_e) + + if self.training: + commitment_loss = ( + F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + * self.commitment + ) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + commit_loss = commitment_loss + codebook_loss + else: + commit_loss = torch.zeros(z.shape[0], device=z.device) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = rearrange(z_q, "b d t -> b t d") + z_q = self.out_proj(z_q) + z_q = rearrange(z_q, "b t d -> b d t") + + return z_q, indices, commit_loss + + def vq2emb(self, vq, proj=True): + emb = self.embed_code(vq) + if proj: + emb = self.out_proj(emb) + return emb.transpose(1, 2) + + def get_emb(self): + return self.codebook.weight + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + # L2 normalize encodings and codebook + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py b/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..d22d88d584df625234d865a63e0fdb709fdf77a2 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch +from torch import nn +from .fvq import FactorizedVectorQuantize + + +class ResidualVQ(nn.Module): + """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf""" + + def __init__(self, *, num_quantizers, codebook_size, **kwargs): + super().__init__() + VQ = FactorizedVectorQuantize + if type(codebook_size) == int: + codebook_size = [codebook_size] * num_quantizers + self.layers = nn.ModuleList( + [VQ(codebook_size=2**size, **kwargs) for size in codebook_size] + ) + self.num_quantizers = num_quantizers + self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0) + self.dropout_type = kwargs.get("dropout_type", None) + + def forward(self, x, n_quantizers=None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + all_quantized = [] + + if n_quantizers is None: + n_quantizers = self.num_quantizers + if self.training: + n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1 + if self.dropout_type == "linear": + dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],)) + elif self.dropout_type == "exp": + dropout = torch.randint( + 1, int(math.log2(self.num_quantizers)), (x.shape[0],) + ) + dropout = torch.pow(2, dropout) + n_dropout = int(x.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(x.device) + + for idx, layer in enumerate(self.layers): + if not self.training and idx >= n_quantizers: + break + quantized, indices, loss = layer(residual) + + mask = ( + torch.full((x.shape[0],), fill_value=idx, device=x.device) + < n_quantizers + ) + + residual = residual - quantized + + quantized_out = quantized_out + quantized * mask[:, None, None] + + # loss + loss = (loss * mask).mean() + + all_indices.append(indices) + all_losses.append(loss) + all_quantized.append(quantized) + all_losses, all_indices, all_quantized = map( + torch.stack, (all_losses, all_indices, all_quantized) + ) + return quantized_out, all_indices, all_losses, all_quantized + + def vq2emb(self, vq): + # vq: [n_quantizers, B, T] + quantized_out = 0.0 + for idx, layer in enumerate(self.layers): + quantized = layer.vq2emb(vq[idx]) + quantized_out += quantized + return quantized_out + + def get_emb(self): + embs = [] + for idx, layer in enumerate(self.layers): + embs.append(layer.get_emb()) + return embs diff --git a/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py b/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..146d0f364dd17c385babb2c903f33378038556db --- /dev/null +++ b/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py @@ -0,0 +1,234 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import math +from torch.nn import functional as F + + +class StyleAdaptiveLayerNorm(nn.Module): + def __init__(self, normalized_shape, eps=1e-5): + super().__init__() + self.in_dim = normalized_shape + self.norm = nn.LayerNorm(self.in_dim, eps=eps, elementwise_affine=False) + self.style = nn.Linear(self.in_dim, self.in_dim * 2) + self.style.bias.data[: self.in_dim] = 1 + self.style.bias.data[self.in_dim :] = 0 + + def forward(self, x, condition): + # x: (B, T, d); condition: (B, T, d) + + style = self.style(torch.mean(condition, dim=1, keepdim=True)) + + gamma, beta = style.chunk(2, -1) + + out = self.norm(x) + + out = gamma * out + beta + return out + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout, max_len=5000): + super().__init__() + + self.dropout = dropout + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[: x.size(0)] + return F.dropout(x, self.dropout, training=self.training) + + +class TransformerFFNLayer(nn.Module): + def __init__( + self, encoder_hidden, conv_filter_size, conv_kernel_size, encoder_dropout + ): + super().__init__() + + self.encoder_hidden = encoder_hidden + self.conv_filter_size = conv_filter_size + self.conv_kernel_size = conv_kernel_size + self.encoder_dropout = encoder_dropout + + self.ffn_1 = nn.Conv1d( + self.encoder_hidden, + self.conv_filter_size, + self.conv_kernel_size, + padding=self.conv_kernel_size // 2, + ) + self.ffn_1.weight.data.normal_(0.0, 0.02) + self.ffn_2 = nn.Linear(self.conv_filter_size, self.encoder_hidden) + self.ffn_2.weight.data.normal_(0.0, 0.02) + + def forward(self, x): + # x: (B, T, d) + x = self.ffn_1(x.permute(0, 2, 1)).permute( + 0, 2, 1 + ) # (B, T, d) -> (B, d, T) -> (B, T, d) + x = F.relu(x) + x = F.dropout(x, self.encoder_dropout, training=self.training) + x = self.ffn_2(x) + return x + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + encoder_hidden, + encoder_head, + conv_filter_size, + conv_kernel_size, + encoder_dropout, + use_cln, + ): + super().__init__() + self.encoder_hidden = encoder_hidden + self.encoder_head = encoder_head + self.conv_filter_size = conv_filter_size + self.conv_kernel_size = conv_kernel_size + self.encoder_dropout = encoder_dropout + self.use_cln = use_cln + + if not self.use_cln: + self.ln_1 = nn.LayerNorm(self.encoder_hidden) + self.ln_2 = nn.LayerNorm(self.encoder_hidden) + else: + self.ln_1 = StyleAdaptiveLayerNorm(self.encoder_hidden) + self.ln_2 = StyleAdaptiveLayerNorm(self.encoder_hidden) + + self.self_attn = nn.MultiheadAttention( + self.encoder_hidden, self.encoder_head, batch_first=True + ) + + self.ffn = TransformerFFNLayer( + self.encoder_hidden, + self.conv_filter_size, + self.conv_kernel_size, + self.encoder_dropout, + ) + + def forward(self, x, key_padding_mask, conditon=None): + # x: (B, T, d); key_padding_mask: (B, T), mask is 0; condition: (B, T, d) + + # self attention + residual = x + if self.use_cln: + x = self.ln_1(x, conditon) + else: + x = self.ln_1(x) + + if key_padding_mask != None: + key_padding_mask_input = ~(key_padding_mask.bool()) + else: + key_padding_mask_input = None + x, _ = self.self_attn( + query=x, key=x, value=x, key_padding_mask=key_padding_mask_input + ) + x = F.dropout(x, self.encoder_dropout, training=self.training) + x = residual + x + + # ffn + residual = x + if self.use_cln: + x = self.ln_2(x, conditon) + else: + x = self.ln_2(x) + x = self.ffn(x) + x = residual + x + + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + enc_emb_tokens=None, + encoder_layer=4, + encoder_hidden=256, + encoder_head=4, + conv_filter_size=1024, + conv_kernel_size=5, + encoder_dropout=0.1, + use_cln=False, + cfg=None, + ): + super().__init__() + + self.encoder_layer = ( + encoder_layer if encoder_layer is not None else cfg.encoder_layer + ) + self.encoder_hidden = ( + encoder_hidden if encoder_hidden is not None else cfg.encoder_hidden + ) + self.encoder_head = ( + encoder_head if encoder_head is not None else cfg.encoder_head + ) + self.conv_filter_size = ( + conv_filter_size if conv_filter_size is not None else cfg.conv_filter_size + ) + self.conv_kernel_size = ( + conv_kernel_size if conv_kernel_size is not None else cfg.conv_kernel_size + ) + self.encoder_dropout = ( + encoder_dropout if encoder_dropout is not None else cfg.encoder_dropout + ) + self.use_cln = use_cln if use_cln is not None else cfg.use_cln + + if enc_emb_tokens != None: + self.use_enc_emb = True + self.enc_emb_tokens = enc_emb_tokens + else: + self.use_enc_emb = False + + self.position_emb = PositionalEncoding( + self.encoder_hidden, self.encoder_dropout + ) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerEncoderLayer( + self.encoder_hidden, + self.encoder_head, + self.conv_filter_size, + self.conv_kernel_size, + self.encoder_dropout, + self.use_cln, + ) + for i in range(self.encoder_layer) + ] + ) + + if self.use_cln: + self.last_ln = StyleAdaptiveLayerNorm(self.encoder_hidden) + else: + self.last_ln = nn.LayerNorm(self.encoder_hidden) + + def forward(self, x, key_padding_mask, condition=None): + if len(x.shape) == 2 and self.use_enc_emb: + x = self.enc_emb_tokens(x) + x = self.position_emb(x) + else: + x = self.position_emb(x) # (B, T, d) + + for layer in self.layers: + x = layer(x, key_padding_mask, condition) + + if self.use_cln: + x = self.last_ln(x, condition) + else: + x = self.last_ln(x) + + return x diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/model.py b/indextts/utils/maskgct/models/codec/speechtokenizer/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b722d38653acdb348a40b1bd8ff3b94ceb2db563 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/model.py @@ -0,0 +1,184 @@ +# Copyright (c) 2023 Amphion. +# +# This code is modified from https://github.com/ZhangXInFD/SpeechTokenizer/blob/main/speechtokenizer/model.py +# Licensed under Apache License 2.0 + +from .modules.seanet import SEANetEncoder, SEANetDecoder +from .modules.quantization import ResidualVectorQuantizer +import torch.nn as nn +from einops import rearrange +import torch +import numpy as np + + +class SpeechTokenizer(nn.Module): + def __init__(self, config): + """ + + Parameters + ---------- + config : json + Model Config. + + """ + super().__init__() + self.encoder = SEANetEncoder( + n_filters=config.get("n_filters"), + dimension=config.get("dimension"), + ratios=config.get("strides"), + lstm=config.get("lstm_layers"), + bidirectional=config.get("bidirectional"), + dilation_base=config.get("dilation_base"), + residual_kernel_size=config.get("residual_kernel_size"), + n_residual_layers=config.get("n_residual_layers"), + activation=config.get("activation"), + ) + self.sample_rate = config.get("sample_rate") + self.n_q = config.get("n_q") + self.downsample_rate = np.prod(config.get("strides")) + if config.get("dimension") != config.get("semantic_dimension"): + self.transform = nn.Linear( + config.get("dimension"), config.get("semantic_dimension") + ) + else: + self.transform = nn.Identity() + self.quantizer = ResidualVectorQuantizer( + dimension=config.get("dimension"), + n_q=config.get("n_q"), + bins=config.get("codebook_size"), + ) + self.decoder = SEANetDecoder( + n_filters=config.get("n_filters"), + dimension=config.get("dimension"), + ratios=config.get("strides"), + lstm=config.get("lstm_layers"), + bidirectional=False, + dilation_base=config.get("dilation_base"), + residual_kernel_size=config.get("residual_kernel_size"), + n_residual_layers=config.get("n_residual_layers"), + activation=config.get("activation"), + ) + + @classmethod + def load_from_checkpoint(cls, config_path: str, ckpt_path: str): + """ + + Parameters + ---------- + config_path : str + Path of model configuration file. + ckpt_path : str + Path of model checkpoint. + + Returns + ------- + model : SpeechTokenizer + SpeechTokenizer model. + + """ + import json + + with open(config_path) as f: + cfg = json.load(f) + model = cls(cfg) + params = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(params) + return model + + def forward(self, x: torch.tensor, n_q: int = None, layers: list = [0]): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape: (batch, channels, timesteps). + n_q : int, optional + Number of quantizers in RVQ used to encode. The default is all layers. + layers : list[int], optional + Layers of RVQ should return quantized result. The default is the first layer. + + Returns + ------- + o : torch.tensor + Output wavs. Shape: (batch, channels, timesteps). + commit_loss : torch.tensor + Commitment loss from residual vector quantizers. + feature : torch.tensor + Output of RVQ's first layer. Shape: (batch, timesteps, dimension) + + """ + n_q = n_q if n_q else self.n_q + e = self.encoder(x) + quantized, codes, commit_loss, quantized_list = self.quantizer( + e, n_q=n_q, layers=layers + ) + feature = rearrange(quantized_list[0], "b d t -> b t d") + feature = self.transform(feature) + o = self.decoder(quantized) + return o, commit_loss, feature + + def forward_feature(self, x: torch.tensor, layers: list = None): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape should be (batch, channels, timesteps). + layers : list[int], optional + Layers of RVQ should return quantized result. The default is all layers. + + Returns + ------- + quantized_list : list[torch.tensor] + Quantized of required layers. + + """ + e = self.encoder(x) + layers = layers if layers else list(range(self.n_q)) + quantized, codes, commit_loss, quantized_list = self.quantizer(e, layers=layers) + return quantized_list + + def encode(self, x: torch.tensor, n_q: int = None, st: int = None): + """ + + Parameters + ---------- + x : torch.tensor + Input wavs. Shape: (batch, channels, timesteps). + n_q : int, optional + Number of quantizers in RVQ used to encode. The default is all layers. + st : int, optional + Start quantizer index in RVQ. The default is 0. + + Returns + ------- + codes : torch.tensor + Output indices for each quantizer. Shape: (n_q, batch, timesteps) + + """ + e = self.encoder(x) + if st is None: + st = 0 + n_q = n_q if n_q else self.n_q + codes = self.quantizer.encode(e, n_q=n_q, st=st) + return codes + + def decode(self, codes: torch.tensor, st: int = 0): + """ + + Parameters + ---------- + codes : torch.tensor + Indices for each quantizer. Shape: (n_q, batch, timesteps). + st : int, optional + Start quantizer index in RVQ. The default is 0. + + Returns + ------- + o : torch.tensor + Reconstruct wavs from codes. Shape: (batch, channels, timesteps) + + """ + quantized = self.quantizer.decode(codes, st=st) + o = self.decoder(quantized) + return o diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0581347c1300a65bfb84e4ae581526cc6edcc1ca --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch modules.""" + +# flake8: noqa +from .conv import ( + pad1d, + unpad1d, + NormConv1d, + NormConvTranspose1d, + NormConv2d, + NormConvTranspose2d, + SConv1d, + SConvTranspose1d, +) +from .lstm import SLSTM +from .seanet import SEANetEncoder, SEANetDecoder diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..0352b8bfa322b0c166bb068fa18c3c3a46cb498e --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py @@ -0,0 +1,346 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .norm import ConvLayerNorm + + +CONV_NORMALIZATIONS = frozenset( + [ + "none", + "weight_norm", + "spectral_norm", + "time_layer_norm", + "layer_norm", + "time_group_norm", + ] +) + + +def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == "weight_norm": + return weight_norm(module) + elif norm == "spectral_norm": + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == "layer_norm": + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == "time_group_norm": + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "zero", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose2d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = "reflect", + ): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn( + "SConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + padding_total = (kernel_size - 1) * dilation - (stride - 1) + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: tp.Dict[str, tp.Any] = {}, + ): + super().__init__() + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7e431238872e3175c5b379f69cc786bc0486a6 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""LSTM layers module.""" + +from torch import nn + + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + + def __init__( + self, + dimension: int, + num_layers: int = 2, + skip: bool = True, + bidirectional: bool = False, + ): + super().__init__() + self.bidirectional = bidirectional + self.skip = skip + self.lstm = nn.LSTM( + dimension, dimension, num_layers, bidirectional=bidirectional + ) + + def forward(self, x): + x = x.permute(2, 0, 1) + y, _ = self.lstm(x) + if self.bidirectional: + x = x.repeat(1, 1, 2) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5eaefd6b6103777d49c6fca2b071870371b7c5 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Normalization modules.""" + +import typing as tp + +import einops +import torch +from torch import nn + + +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. + """ + + def __init__( + self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs + ): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, "b ... t -> b t ...") + x = super().forward(x) + x = einops.rearrange(x, "b t ... -> b ... t") + return diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79d90a1a2b310074fffdd0bd03cd3e60193a8de6 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py new file mode 100644 index 0000000000000000000000000000000000000000..5695ea84451a110875c530558b93d9ea915500c7 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py @@ -0,0 +1,317 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Arithmetic coder.""" + +import io +import math +import random +import typing as tp +import torch + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf( + pdf: torch.Tensor, + total_range_bits: int, + roundoff: float = 1e-8, + min_range: int = 2, + check: bool = True, +) -> torch.Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (torch.Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. + roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. Use this to avoid pathological behavior is a value + that is expected to be rare actually happens in real life. + check (bool): if True, checks that nothing bad happened, can be deactivated for speed. + """ + pdf = pdf.detach() + if roundoff: + pdf = (pdf / roundoff).floor() * roundoff + # interpolate with uniform distribution to achieve desired minimum probability. + total_range = 2**total_range_bits + cardinality = len(pdf) + alpha = min_range * cardinality / total_range + assert alpha <= 1, "you must reduce min_range" + ranges = (((1 - alpha) * total_range) * pdf).floor().long() + ranges += min_range + quantized_cdf = torch.cumsum(ranges, dim=-1) + if min_range < 2: + raise ValueError("min_range must be at least 2.") + if check: + assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1] + if ( + (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range + ).any() or quantized_cdf[0] < min_range: + raise ValueError("You must increase your total_range_bits.") + return quantized_cdf + + +class ArithmeticCoder: + """ArithmeticCoder, + Let us take a distribution `p` over `N` symbols, and assume we have a stream + of random variables `s_t` sampled from `p`. Let us assume that we have a budget + of `B` bits that we can afford to write on device. There are `2**B` possible numbers, + corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single + sequence `(s_t)` by doing the following: + + 1) Initialize the current range to` [0 ** 2 B - 1]`. + 2) For each time step t, split the current range into contiguous chunks, + one for each possible outcome, with size roughly proportional to `p`. + For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks + would be `{[0, 2], [3, 3]}`. + 3) Select the chunk corresponding to `s_t`, and replace the current range with this. + 4) When done encoding all the values, just select any value remaining in the range. + + You will notice that this procedure can fail: for instance if at any point in time + the range is smaller than `N`, then we can no longer assign a non-empty chunk to each + possible outcome. Intuitively, the more likely a value is, the less the range width + will reduce, and the longer we can go on encoding values. This makes sense: for any efficient + coding scheme, likely outcomes would take less bits, and more of them can be coded + with a fixed budget. + + In practice, we do not know `B` ahead of time, but we have a way to inject new bits + when the current range decreases below a given limit (given by `total_range_bits`), without + having to redo all the computations. If we encode mostly likely values, we will seldom + need to inject new bits, but a single rare value can deplete our stock of entropy! + + In this explanation, we assumed that the distribution `p` was constant. In fact, the present + code works for any sequence `(p_t)` possibly different for each timestep. + We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller + the KL between the true distribution and `p_t`, the most efficient the coding will be. + + Args: + fo (IO[bytes]): file-like object to which the bytes will be written to. + total_range_bits (int): the range `M` described above is `2 ** total_range_bits. + Any time the current range width fall under this limit, new bits will + be injected to rescale the initial range. + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: torch.Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. + """ + while self.delta < 2**self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int( + math.ceil(range_low * (self.delta / (2**self.total_range_bits))) + ) + effective_high = int( + math.floor(range_high * (self.delta / (2**self.total_range_bits))) + ) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, ( + effective_low, + effective_high, + range_low, + range_high, + ) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + outs = self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + return outs + + def flush(self): + """Flush the remaining information to the stream.""" + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. + + If the AC encoder current range is [L, H], with `L` and `H` having the some common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. + + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + self._last: tp.Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + self.current -= b1 << self.max_bit + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. + """ + while self.delta < 2**self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int( + math.ceil(range_low * (self.delta / (2**self.total_range_bits))) + ) + effective_high = int( + math.floor(range_high * (self.delta / (2**self.total_range_bits))) + ) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..5799725598983bccb5c0644550f52303b15471c3 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py @@ -0,0 +1,388 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +from .distrib import broadcast_tensors, rank + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + # broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + # broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward( + self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None + ): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + out_quantized = [] + + n_q = n_q or len(self.layers) + + for i, layer in enumerate(self.layers[:n_q]): + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + if layers and i in layers: + out_quantized.append(quantized) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses, out_quantized + + def encode( + self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None + ) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + st = st or 0 + for layer in self.layers[st:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[st + i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9a9b83e47cc3403354ce4c5e34eb0b279df5f2 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py @@ -0,0 +1,135 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + # print('params[0].device ', params[0].device) + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. + raise RuntimeError( + f"Mismatch in number of params: ours is {len(params)}, " + "at least one worker has a different one." + ) + + +def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + # src = int(rank()) # added code + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = torch.distributed.all_reduce( + buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True + ) + else: + handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = torch.distributed.all_reduce( + p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True + ) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: tp.Dict[str, float], count=1.0): + """Average a dictionary of metrics across all workers, using the optional + `count` as unormalized weight. + """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py new file mode 100644 index 0000000000000000000000000000000000000000..ec7df0f9a0f58f2c757a710b682e10977ba58298 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import math +import typing as tp + +import torch +from torch import nn + +from .core_vq import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward( + self, + x: torch.Tensor, + n_q: tp.Optional[int] = None, + layers: tp.Optional[list] = None, + ) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + n_q (int): Number of quantizer used to quantize. Default: All quantizers. + layers (list): Layer that need to return quantized. Defalt: None. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated numbert quantizers and layer quantized required to return. + """ + n_q = n_q if n_q else self.n_q + if layers and max(layers) >= n_q: + raise ValueError( + f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B." + ) + quantized, codes, commit_loss, quantized_list = self.vq( + x, n_q=n_q, layers=layers + ) + return quantized, codes, torch.mean(commit_loss), quantized_list + + def encode( + self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None + ) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + Args: + x (torch.Tensor): Input tensor. + n_q (int): Number of quantizer used to quantize. Default: All quantizers. + st (int): Start to encode input from which layers. Default: 0. + """ + n_q = n_q if n_q else self.n_q + st = st or 0 + codes = self.vq.encode(x, n_q=n_q, st=st) + return codes + + def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor: + """Decode the given codes to the quantized representation. + Args: + codes (torch.Tensor): Input indices for each quantizer. + st (int): Start to decode input codes from which layers. Default: 0. + """ + quantized = self.vq.decode(codes, st=st) + return quantized diff --git a/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py new file mode 100644 index 0000000000000000000000000000000000000000..481de20c7ef05210e1bdf9092fe249060d06d686 --- /dev/null +++ b/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py @@ -0,0 +1,414 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# This source file is copied from https://github.com/facebookresearch/encodec + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Encodec SEANet-based encoder and decoder implementation.""" + +import typing as tp + +import numpy as np +import torch.nn as nn +import torch + +from . import SConv1d, SConvTranspose1d, SLSTM + + +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + Args: + dim (int): Dimension of the input/output + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3) + true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. + """ + + def __init__( + self, + dim: int, + kernel_sizes: tp.List[int] = [3, 1], + dilations: tp.List[int] = [1, 1], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): + super().__init__() + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" + act = getattr(nn, activation) if activation != "Snake" else Snake1d + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params) if activation != "Snake" else act(in_chs), + SConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + bidirectional: bool = False, + ): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) # 计算乘积 + + act = getattr(nn, activation) if activation != "Snake" else Snake1d + mult = 1 + model: tp.List[nn.Module] = [ + SConv1d( + channels, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + # Add downsampling layers + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + mult *= 2 + + if lstm: + model += [ + SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional) + ] + + mult = mult * 2 if bidirectional else mult + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: tp.Optional[str] = None, + final_activation_params: tp.Optional[dict] = None, + norm: str = "weight_norm", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + trim_right_ratio: float = 1.0, + bidirectional: bool = False, + ): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) if activation != "Snake" else Snake1d + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + SConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + + if lstm: + model += [ + SLSTM(mult * n_filters, num_layers=lstm, bidirectional=bidirectional) + ] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add upsampling layers + model += [ + ( + act(**activation_params) + if activation != "Snake" + else act(mult * n_filters) + ), + SConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params) if activation != "Snake" else act(n_filters), + SConv1d( + n_filters, + channels, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [final_act(**final_activation_params)] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def test(): + import torch + + encoder = SEANetEncoder() + decoder = SEANetDecoder() + x = torch.randn(1, 1, 24000) + z = encoder(x) + print("z ", z.shape) + assert 1 == 2 + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + +if __name__ == "__main__": + test() diff --git a/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py b/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py new file mode 100755 index 0000000000000000000000000000000000000000..224ad25133cdec5b9aeb50de7b7a82612362b44f --- /dev/null +++ b/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py @@ -0,0 +1,592 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) ByteDance, Inc. and its affiliates. +# Copyright (c) Chutong Meng +# +# This source code is licensed under the CC BY-NC license found in the +# LICENSE file in the root directory of this source tree. +# Based on AudioDec (https://github.com/facebookresearch/AudioDec) + + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class VectorQuantize(nn.Module): + """Vector quantization w/ exponential moving averages (EMA)""" + + def __init__( + self, + dim: int, + codebook_size: int, + decay=0.8, + commitment=1.0, + eps=1e-5, + n_embed=None, + ): + super().__init__() + n_embed = self.default(n_embed, codebook_size) + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + self.commitment = commitment + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + @property + def codebook(self): + return self.embed.transpose(0, 1) + + def exists(self, val): + return val is not None + + def default(self, val, d): + return val if self.exists(val) else d + + def ema_inplace(self, moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + def laplace_smoothing(self, x, n_categories, eps=1e-5): + return (x + eps) / (x.sum() + n_categories * eps) + + def forward(self, input): + dtype = input.dtype + flatten = input.reshape(-1, self.dim) + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True) + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = F.embedding(embed_ind, self.embed.transpose(0, 1)) + + if self.training: + self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + self.ema_inplace(self.embed_avg, embed_sum, self.decay) + cluster_size = ( + self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + loss = F.mse_loss(quantize.detach(), input) * self.commitment + quantize = input + (quantize - input).detach() + + avg_probs = torch.mean(embed_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return quantize, loss, perplexity + + def forward_index(self, input): + dtype = input.dtype + flatten = input.reshape(-1, self.dim) + dist = ( + flatten.pow(2).sum(1, keepdim=True) + - 2 * flatten @ self.embed + + self.embed.pow(2).sum(0, keepdim=True) + ) + _, embed_ind = (-dist).max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = F.embedding(embed_ind, self.embed.transpose(0, 1)) + quantize = input + (quantize - input).detach() + + return quantize, embed_ind + + +class ResidualVQ(nn.Module): + """Residual VQ following algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf""" + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantize(**kwargs) for _ in range(num_quantizers)] + ) + + def forward(self, x): + quantized_out = 0.0 + residual = x + all_losses = [] + all_perplexities = [] + for layer in self.layers: + quantized, loss, perplexity = layer(residual) + # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33 + # We found considering only the 1st layer VQ's graident results in better performance + # residual = residual - quantized.detach() # considering all layers' graidents + residual = ( + residual - quantized + ) # considering only the first layer's graident + quantized_out = quantized_out + quantized + all_losses.append(loss) + all_perplexities.append(perplexity) + all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities)) + return quantized_out, all_losses, all_perplexities + + def forward_index(self, x, flatten_idx=False): + """ + all_indices: [num_of_quantizers, B, T] + """ + quantized_out = 0.0 + residual = x + all_indices = [] + for i, layer in enumerate(self.layers): + quantized, indices = layer.forward_index(residual) + # residual = residual - quantized.detach() + residual = residual - quantized + quantized_out = quantized_out + quantized + if flatten_idx: + indices += self.codebook_size * i + all_indices.append(indices) + all_indices = torch.stack(all_indices) + return quantized_out, all_indices + + def initial(self): + self.codebook = [] + for layer in self.layers: + self.codebook.append(layer.codebook) + self.codebook_size = self.codebook[0].size(0) + self.codebook = torch.stack(self.codebook) + self.codebook = self.codebook.reshape(-1, self.codebook.size(-1)) + + def lookup(self, indices): + quantized_out = F.embedding(indices, self.codebook) # Num x T x C + return torch.sum(quantized_out, dim=0, keepdim=True) + + +class Quantizer(nn.Module): + def __init__( + self, + code_dim: int, + codebook_num: int, + codebook_size: int, + ): + super().__init__() + self.codebook = ResidualVQ( + dim=code_dim, num_quantizers=codebook_num, codebook_size=codebook_size + ) + + def initial(self): + self.codebook.initial() + + def forward(self, z): + zq, vqloss, perplexity = self.codebook(z.transpose(2, 1)) + zq = zq.transpose(2, 1) + return zq, vqloss, perplexity + + def inference(self, z): + zq, indices = self.codebook.forward_index(z.transpose(2, 1)) + zq = zq.transpose(2, 1) + return zq, indices + + def encode(self, z): + zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True) + return zq, indices + + def decode(self, indices): + z = self.codebook.lookup(indices) + return z + + +class Conv1d1x1(nn.Conv1d): + """1x1 Conv1d.""" + + def __init__(self, in_channels, out_channels, bias=True): + super(Conv1d1x1, self).__init__( + in_channels, out_channels, kernel_size=1, bias=bias + ) + + +class Conv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = -1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + if padding < 0: + padding = (kernel_size - 1) // 2 * dilation + self.dilation = dilation + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """ + Args: + x (Tensor): Float tensor variable with the shape (B, C, T). + Returns: + Tensor: Float tensor variable with the shape (B, C, T). + """ + x = self.conv(x) + return x + + +class ConvTranspose1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding=-1, + output_padding=-1, + groups=1, + bias=True, + ): + super().__init__() + if padding < 0: + padding = (stride + 1) // 2 + if output_padding < 0: + output_padding = 1 if stride % 2 else 0 + self.deconv = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """ + Args: + x (Tensor): Float tensor variable with the shape (B, C, T). + Returns: + Tensor: Float tensor variable with the shape (B, C', T'). + """ + x = self.deconv(x) + return x + + +class ResidualUnit(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + dilation=1, + bias=False, + nonlinear_activation="ELU", + nonlinear_activation_params={}, + ): + super().__init__() + self.activation = getattr(nn, nonlinear_activation)( + **nonlinear_activation_params + ) + self.conv1 = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + dilation=dilation, + bias=bias, + ) + self.conv2 = Conv1d1x1(out_channels, out_channels, bias) + + def forward(self, x): + y = self.conv1(self.activation(x)) + y = self.conv2(self.activation(y)) + return x + y + + +class Projector(nn.Module): + def __init__( + self, input_channels: int, code_dim: int, kernel_size=3, stride=1, bias=False + ): + super().__init__() + self.project = Conv1d( + input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias + ) + + def forward(self, x): + return self.project(x) + + +class EncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + dilations=(1, 1), + unit_kernel_size=3, + bias=True, + ): + super().__init__() + self.res_units = torch.nn.ModuleList() + for dilation in dilations: + self.res_units += [ + ResidualUnit( + in_channels, + in_channels, + kernel_size=unit_kernel_size, + dilation=dilation, + ) + ] + self.num_res = len(self.res_units) + + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=( + 3 if stride == 1 else (2 * stride) + ), # special case: stride=1, do not use kernel=2 + stride=stride, + bias=bias, + ) + + def forward(self, x): + for idx in range(self.num_res): + x = self.res_units[idx](x) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + input_channels: int, + encode_channels: int, + channel_ratios=(1, 1), + strides=(1, 1), + kernel_size=3, + bias=True, + block_dilations=(1, 1), + unit_kernel_size=3, + ): + super().__init__() + assert len(channel_ratios) == len(strides) + + self.conv = Conv1d( + in_channels=input_channels, + out_channels=encode_channels, + kernel_size=kernel_size, + stride=1, + bias=False, + ) + self.conv_blocks = torch.nn.ModuleList() + in_channels = encode_channels + for idx, stride in enumerate(strides): + out_channels = int(encode_channels * channel_ratios[idx]) # could be float + self.conv_blocks += [ + EncoderBlock( + in_channels, + out_channels, + stride, + dilations=block_dilations, + unit_kernel_size=unit_kernel_size, + bias=bias, + ) + ] + in_channels = out_channels + self.num_blocks = len(self.conv_blocks) + self.out_channels = out_channels + + def forward(self, x): + x = self.conv(x) + for i in range(self.num_blocks): + x = self.conv_blocks[i](x) + return x + + +class DecoderBlock(nn.Module): + """Decoder block (no up-sampling)""" + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + dilations=(1, 1), + unit_kernel_size=3, + bias=True, + ): + super().__init__() + + if stride == 1: + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape + stride=stride, + bias=bias, + ) + else: + self.conv = ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(2 * stride), + stride=stride, + bias=bias, + ) + + self.res_units = torch.nn.ModuleList() + for idx, dilation in enumerate(dilations): + self.res_units += [ + ResidualUnit( + out_channels, + out_channels, + kernel_size=unit_kernel_size, + dilation=dilation, + ) + ] + self.num_res = len(self.res_units) + + def forward(self, x): + x = self.conv(x) + for idx in range(self.num_res): + x = self.res_units[idx](x) + return x + + +class Decoder(nn.Module): + def __init__( + self, + code_dim: int, + output_channels: int, + decode_channels: int, + channel_ratios=(1, 1), + strides=(1, 1), + kernel_size=3, + bias=True, + block_dilations=(1, 1), + unit_kernel_size=3, + ): + super().__init__() + assert len(channel_ratios) == len(strides) + + self.conv1 = Conv1d( + in_channels=code_dim, + out_channels=int(decode_channels * channel_ratios[0]), + kernel_size=kernel_size, + stride=1, + bias=False, + ) + + self.conv_blocks = torch.nn.ModuleList() + for idx, stride in enumerate(strides): + in_channels = int(decode_channels * channel_ratios[idx]) + if idx < (len(channel_ratios) - 1): + out_channels = int(decode_channels * channel_ratios[idx + 1]) + else: + out_channels = decode_channels + self.conv_blocks += [ + DecoderBlock( + in_channels, + out_channels, + stride, + dilations=block_dilations, + unit_kernel_size=unit_kernel_size, + bias=bias, + ) + ] + self.num_blocks = len(self.conv_blocks) + + self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False) + + def forward(self, z): + x = self.conv1(z) + for i in range(self.num_blocks): + x = self.conv_blocks[i](x) + x = self.conv2(x) + return x + + +class VevoRepCodec(nn.Module): + def __init__( + self, + input_channels=768, + output_channels=768, + encode_channels=768, + decode_channels=768, + code_dim=768, + codebook_num=1, + codebook_size=1024, + bias=True, + enc_ratios=(1, 1), + dec_ratios=(1, 1), + enc_strides=(1, 1), + dec_strides=(1, 1), + enc_kernel_size=3, + dec_kernel_size=3, + enc_block_dilations=(1, 1), + enc_block_kernel_size=3, + dec_block_dilations=(1, 1), + dec_block_kernel_size=3, + ): + super().__init__() + + self.input_channels = input_channels + + self.encoder = Encoder( + input_channels=input_channels, + encode_channels=encode_channels, + channel_ratios=enc_ratios, + strides=enc_strides, + kernel_size=enc_kernel_size, + bias=bias, + block_dilations=enc_block_dilations, + unit_kernel_size=enc_block_kernel_size, + ) + + self.decoder = Decoder( + code_dim=code_dim, + output_channels=output_channels, + decode_channels=decode_channels, + channel_ratios=dec_ratios, + strides=dec_strides, + kernel_size=dec_kernel_size, + bias=bias, + block_dilations=dec_block_dilations, + unit_kernel_size=dec_block_kernel_size, + ) + + self.projector = Projector( + input_channels=self.encoder.out_channels, + code_dim=code_dim, + kernel_size=3, + stride=1, + bias=False, + ) + + self.quantizer = Quantizer( + code_dim=code_dim, codebook_num=codebook_num, codebook_size=codebook_size + ) + + def forward(self, x): + x = self.encoder(x) + z = self.projector(x) + zq, vqloss, perplexity = self.quantizer(z) + y = self.decoder(zq) + return y, zq, z, vqloss, perplexity diff --git a/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt b/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt new file mode 100644 index 0000000000000000000000000000000000000000..1a8ecb924668659f3b1a9c35b02b2f8839fd8c5a --- /dev/null +++ b/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9c176c2b8850ab2e3ba828bbfa969deaf4566ce55db5f2687b8430b87526ad2 +size 9343 diff --git a/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py b/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py new file mode 100644 index 0000000000000000000000000000000000000000..14c8ae842b5952b6c7ba3aa6e442dbb0277e2e87 --- /dev/null +++ b/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py @@ -0,0 +1,650 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel +import torch +import torch.nn.functional as F +import numpy as np +import os +import torch.nn as nn +from typing import List, Optional, Tuple, Union +import math + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from transformers.models.llama.modeling_llama import BaseModelOutputWithPast + + +# sinusoidal positional encoding +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] * 1.0 + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class LlamaAdaptiveRMSNorm(nn.Module): + def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024): + super().__init__() + self.to_weight = nn.Linear(dim_cond, hidden_size) + nn.init.zeros_(self.to_weight.weight) + nn.init.ones_(self.to_weight.bias) + self.variance_epsilon = eps + self._is_hf_initialized = True # disable automatic init + + def forward(self, hidden_states, cond_embedding): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + weight = self.to_weight(cond_embedding) + if len(weight.shape) == 2: + weight = weight.unsqueeze(1) + + return (weight * hidden_states).to(input_dtype) + + +class LlamaNARDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + """Override to adaptive layer norm""" + super().__init__(config, layer_idx) # init attention, mlp, etc. + self.input_layernorm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + self.post_attention_layernorm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + + # add `cond` in forward function + def forward( + self, + hidden_states: torch.Tensor, + cond_embedding: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm( + hidden_states, cond_embedding=cond_embedding + ) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm( + hidden_states, cond_embedding=cond_embedding + ) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + def __init__(self, config: LlamaConfig, layer_idx: int): + """Override to adaptive layer norm""" + super().__init__(config, layer_idx) # init attention, mlp, etc. + self.layer_idx = layer_idx + self.input_layernorm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + self.post_attention_layernorm = LlamaAdaptiveRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size + ) + + def forward( + self, + hidden_states: torch.Tensor, + cond_embedding: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm( + hidden_states, cond_embedding=cond_embedding + ) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm( + hidden_states, cond_embedding=cond_embedding + ) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class DiffLlama(LlamaModel): + def __init__( + self, + hidden_size=1024, + num_heads=16, + num_layers=16, + config=LlamaConfig(0, 256, 1024, 1, 1), + ): + super().__init__(config) + + self.layers = nn.ModuleList( + [ + LlamaNARDecoderLayer( + LlamaConfig( + hidden_size=hidden_size, + num_attention_heads=num_heads, + max_position_embeddings=4096, + intermediate_size=hidden_size * 4, + ), + layer_idx=i, + ) + for i in range(num_layers) + ] + ) + + self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size) + + self.diff_step_embedding = SinusoidalPosEmb(hidden_size) + self.diff_step_mlp = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, hidden_size), + ) + + # self.position_embedding = PositionalEncoding(hidden_size, dropout=0.0) + + self.cond_mlp = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, hidden_size), + ) + + for layer in self.layers: + layer.input_layernorm = LlamaAdaptiveRMSNorm( + hidden_size, dim_cond=hidden_size + ) + layer.post_attention_layernorm = LlamaAdaptiveRMSNorm( + hidden_size, dim_cond=hidden_size + ) + + self.post_init() + + # self.reset_parameters() + + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create noncausal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + def _expand_mask( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None + ): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = ( + mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + ) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + x, + diffusion_step, + cond, + x_mask, + input_ids: torch.LongTensor = None, # [num_quant, B, T] + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + # retrieve some shape info + batch_size, seq_length, _ = x.shape + + # condtion mlp + cond_embedding = self.cond_mlp(cond) # (B, T, C) + + # diffusion step embedding + diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device) + diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C) + x = x + cond_embedding + + inputs_embeds = x + attention_mask = x_mask + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + raise NotImplementedError + + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cond_embedding=diffusion_step, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + return hidden_states + + +class DiffLlamaPrefix(LlamaModel): + def __init__( + self, + hidden_size=1024, + num_heads=16, + num_layers=16, + config=LlamaConfig(0, 256, 1024, 1, 1), + ): + super().__init__(config) + + self.layers = nn.ModuleList( + [ + LlamaNARDecoderLayer( + LlamaConfig( + hidden_size=hidden_size, + num_attention_heads=num_heads, + max_position_embeddings=4096, + intermediate_size=hidden_size * 4, + ), + layer_idx=i, + ) + for i in range(num_layers) + ] + ) + + self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size) + + self.diff_step_embedding = SinusoidalPosEmb(hidden_size) + self.diff_step_mlp = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, hidden_size), + ) + + self.cond_mlp = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, hidden_size), + ) + + for layer in self.layers: + layer.input_layernorm = LlamaAdaptiveRMSNorm( + hidden_size, dim_cond=hidden_size + ) + layer.post_attention_layernorm = LlamaAdaptiveRMSNorm( + hidden_size, dim_cond=hidden_size + ) + + self.embed_tokens = None + + self.post_init() + + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create noncausal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + def _expand_mask( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None + ): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = ( + mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + ) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + x, + diffusion_step, + x_mask, + phone_embedding: Optional[torch.LongTensor] = None, + phone_mask: Optional[torch.FloatTensor] = None, + input_ids: torch.LongTensor = None, # [num_quant, B, T] + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + # retrieve some shape info + + phone_embedding = self.cond_mlp(phone_embedding) # (B, T, C) + phone_length = phone_embedding.shape[1] + inputs_embeds = torch.cat([phone_embedding, x], dim=1) + attention_mask = torch.cat([phone_mask, x_mask], dim=1) + + # diffusion step embedding + diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device) + diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C) + + batch_size, seq_length, _ = inputs_embeds.shape + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + raise NotImplementedError + + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cond_embedding=diffusion_step, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + return hidden_states[ + :, + phone_length:, + ] diff --git a/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py b/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3fe1213943232b396871863cf2f57651d0af49 --- /dev/null +++ b/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py @@ -0,0 +1,503 @@ +# Copyright (c) 2024 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +import torch.nn as nn +import math +from einops import rearrange +from indextts.utils.maskgct.models.tts.maskgct.llama_nar import DiffLlama + + +def top_k(logits, thres=0.9): + k = math.ceil((1 - thres) * logits.shape[-1]) + val, ind = logits.topk(k, dim=-1) + probs = torch.full_like(logits, float("-inf")) + probs.scatter_(2, ind, val) + return probs + + +def log(t, eps=1e-10): + return torch.log(t + eps) + + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + + +def gumbel_sample(t, temperature=1.0, dim=-1): + return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim) + + +def top_k(logits, thres=0.9): + k = math.ceil((1 - thres) * logits.shape[-1]) + val, ind = logits.topk(k, dim=-1) + probs = torch.full_like(logits, float("-inf")) + probs.scatter_(2, ind, val) + return probs + + +def log(t, eps=1e-10): + return torch.log(t + eps) + + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + + +def gumbel_sample(t, temperature=1.0, dim=-1): + return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim) + + +class MaskGCT_S2A(nn.Module): + def __init__( + self, + num_quantizer=12, + hidden_size=1024, + num_layers=16, + num_heads=16, + codebook_size=1024, + cfg_scale=0.15, + mask_layer_schedule="linear", + cond_codebook_size=1024, + cond_dim=1024, + predict_layer_1=True, + cfg=None, + ): + super().__init__() + + num_quantizer = ( + cfg.num_quantizer + if cfg is not None and hasattr(cfg, "num_quantizer") + else num_quantizer + ) + hidden_size = ( + cfg.hidden_size + if cfg is not None and hasattr(cfg, "hidden_size") + else hidden_size + ) + num_layers = ( + cfg.num_layers + if cfg is not None and hasattr(cfg, "num_layers") + else num_layers + ) + num_heads = ( + cfg.num_heads + if cfg is not None and hasattr(cfg, "num_heads") + else num_heads + ) + codebook_size = ( + cfg.codebook_size + if cfg is not None and hasattr(cfg, "codebook_size") + else codebook_size + ) + cfg_scale = ( + cfg.cfg_scale + if cfg is not None and hasattr(cfg, "cfg_scale") + else cfg_scale + ) + mask_layer_schedule = ( + cfg.mask_layer_schedule + if cfg is not None and hasattr(cfg, "mask_layer_schedule") + else mask_layer_schedule + ) + cond_codebook_size = ( + cfg.cond_codebook_size + if cfg is not None and hasattr(cfg, "cond_codebook_size") + else cond_codebook_size + ) + cond_dim = ( + cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim + ) + predict_layer_1 = ( + cfg.predict_layer_1 + if cfg is not None and hasattr(cfg, "predict_layer_1") + else predict_layer_1 + ) + + self.num_quantizer = num_quantizer + self.hidden_size = hidden_size + self.codebook_size = codebook_size + self.num_layers = num_layers + self.num_heads = num_heads + self.cfg_scale = cfg_scale + self.mask_layer_schedule = mask_layer_schedule + self.cond_codebook_size = cond_codebook_size + self.cond_dim = cond_dim + self.predict_layer_1 = predict_layer_1 + + self.layer_emb = nn.Embedding(self.num_quantizer, self.hidden_size) + self.mask_emb = nn.Embedding(1, self.hidden_size) + + self.token_emb = torch.nn.ModuleList( + [ + nn.Embedding(self.codebook_size, self.hidden_size) + for _ in range(self.num_quantizer) + ] + ) + + self.to_logits = torch.nn.ModuleList( + [ + nn.Linear(self.hidden_size, self.codebook_size) + for _ in range(self.num_quantizer) + ] + ) + + self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size) + + self.reset_parameters() + + self.diff_estimator = DiffLlama( + hidden_size=hidden_size, + num_heads=self.num_heads, + num_layers=num_layers, + ) + + def mask_prob(self, t): + return torch.sin(t * np.pi / 2).to(t.device) + + def mask_layer(self, t): + # print(self.predict_layer_1) + if self.mask_layer_schedule == "uniform": + if self.predict_layer_1: + mask_layer = torch.randint(0, self.num_quantizer, (1,)).to(t.device) + else: + mask_layer = torch.randint(1, self.num_quantizer, (1,)).to(t.device) + elif self.mask_layer_schedule == "cosine": + if self.predict_layer_1: + weights = torch.tensor( + [ + np.cos(i / self.num_quantizer * np.pi / 2) + for i in range(self.num_quantizer) + ] + ) + else: + weights = torch.tensor( + [0] + + [ + np.cos((i - 1) / self.num_quantizer * np.pi / 2) + for i in range(1, self.num_quantizer) + ] + ) + mask_layer = torch.multinomial(weights, 1).to(t.device) + elif self.mask_layer_schedule == "linear": + if self.predict_layer_1: + weights = torch.tensor( + [self.num_quantizer - i for i in range(self.num_quantizer)] + ) + else: + weights = torch.tensor( + [0] + + [ + self.num_quantizer - (i - 1) + for i in range(1, self.num_quantizer) + ] + ) + weights = weights / weights.sum() + mask_layer = torch.multinomial(weights, 1).to(t.device) + # print(mask_layer) + new_t = t + + return mask_layer, new_t + + def forward_diffusion(self, x0, t): + # x0: (B, T, num_quantizer) + mask_layer, new_t = self.mask_layer(t) # (1,) + mask_prob = self.mask_prob(new_t) # (B,) + mask_token = self.mask_emb(torch.zeros_like(mask_layer)) # (1, hidden_size) + + xt = torch.zeros(x0.shape[0], x0.shape[1], self.hidden_size).to(x0.device) + + cfg_scale = self.cfg_scale + + # get prompt len + if torch.rand(1) > cfg_scale: + prompt_len = torch.randint( + min(x0.shape[1] // 4, 5), x0.shape[1] // 2, (x0.shape[0],) + ).to( + x0.device + ) # (B,) + else: + prompt_len = torch.zeros(x0.shape[0]).to(x0) # (B,) + + # get is prompt + is_prompt = torch.zeros_like(x0[:, :, 0]) # (B, T) + col_indices = ( + torch.arange(is_prompt.shape[1]) + .repeat(is_prompt.shape[0], 1) + .to(prompt_len) + ) # (B, T) + is_prompt[col_indices < prompt_len.unsqueeze(1)] = 1 # (B, T) 1 if prompt + + for idx, token_emb_idx in enumerate(self.token_emb): + if idx < mask_layer: + xt = xt + token_emb_idx(x0[:, :, idx]) # (B, T, hidden_size) + + elif idx == mask_layer: + mask = torch.bernoulli( + torch.ones_like(x0[:, :, idx]) * mask_prob[..., None] + ) # mask if 1, not mask if 0 + # prompt part don't need to be masked + mask[is_prompt.bool()] = 0 + # Ensure at least one token is masked + mask_num = mask[:,].sum(dim=1, keepdim=False) + all_zero_mask = (mask_num == 0).bool() + row_indices_to_modify = torch.nonzero(all_zero_mask) + # mask the first token if all tokens are not masked (may mask pad if random indices) + mask[row_indices_to_modify, prompt_len[row_indices_to_modify]] = 1 + + mask = mask[..., None] # (B, T, 1) + xt = ( + xt + + mask * mask_token[:, None, :] + + (1 - mask) * token_emb_idx(x0[:, :, idx]) + ) # (B, T, hidden_size) + + else: + # prompt part don't need to be masked + xt = ( + xt + + token_emb_idx(x0[:, :, idx]) * is_prompt[..., None] + + mask_token * (1 - is_prompt[..., None]) + ) + + return xt, new_t, mask_layer, mask, prompt_len, mask_prob + + def loss_t(self, x0, x_mask, t, cond=None): + xt, new_t, mask_layer, mask, prompt_len, mask_prob = self.forward_diffusion( + x0, t + ) + # xt: (B, T, hidden_size) + # new_t: (B,) + # mask_layer: (1,) + # mask: (B, T, 1) mask if 1, not mask if 0 + # prompt_len: (B,) + # mask_prob: (B,) + + mask_layer_cond = self.layer_emb(mask_layer).unsqueeze(1) # (1, 1, hidden_size) + cond = cond + mask_layer_cond # (B, T, hidden_size) + + embeds = self.diff_estimator(xt, new_t, cond, x_mask) # (B, T, hidden_size) + + logits = self.to_logits[mask_layer.item()](embeds) # (B, T, codebook_size) + + # final mask used for loss calculation + final_mask = mask * x_mask[..., None] # (B, T, 1) + + return logits, mask_layer, final_mask, x0, prompt_len, mask_prob + + def compute_loss(self, x0, x_mask, cond=None): + # x0: (B, T, num_quantizer) + # x_mask: (B, T) mask is 0 for padding + t = torch.rand(x0.shape[0], device=x0.device, requires_grad=False) + t = torch.clamp(t, 1e-5, 1.0) + return self.loss_t(x0, x_mask, t, cond) + + def reset_parameters(self): + def _reset_parameters(m): + if isinstance(m, nn.MultiheadAttention): + if m._qkv_same_embed_dim: + nn.init.normal_(m.in_proj_weight, std=0.02) + else: + nn.init.normal_(m.q_proj_weight, std=0.02) + nn.init.normal_(m.k_proj_weight, std=0.02) + nn.init.normal_(m.v_proj_weight, std=0.02) + + if m.in_proj_bias is not None: + nn.init.constant_(m.in_proj_bias, 0.0) + nn.init.constant_(m.out_proj.bias, 0.0) + if m.bias_k is not None: + nn.init.xavier_normal_(m.bias_k) + if m.bias_v is not None: + nn.init.xavier_normal_(m.bias_v) + + elif ( + isinstance(m, nn.Conv1d) + or isinstance(m, nn.ConvTranspose1d) + or isinstance(m, nn.Conv2d) + or isinstance(m, nn.ConvTranspose2d) + ): + m.weight.data.normal_(0.0, 0.02) + + elif isinstance(m, nn.Linear): + m.weight.data.normal_(mean=0.0, std=0.02) + if m.bias is not None: + m.bias.data.zero_() + + elif isinstance(m, nn.Embedding): + m.weight.data.normal_(mean=0.0, std=0.02) + if m.padding_idx is not None: + m.weight.data[m.padding_idx].zero_() + + self.apply(_reset_parameters) + + @torch.no_grad() + def reverse_diffusion( + self, + cond, + prompt, + x_mask=None, + prompt_mask=None, + temp=1.5, + filter_thres=0.98, + max_layer=None, + gt_code=None, + n_timesteps=[10, 4, 4, 4, 4, 4, 4, 4], + cfg=1.0, + rescale_cfg=1.0, + ): + + assert ( + len(n_timesteps) == self.num_quantizer + ) # each layer has a number of steps + + prompt_code = prompt # (B, prompt_len, num_quantizer) + prompt_len = prompt_code.shape[1] + target_len = cond.shape[1] - prompt_len + + if x_mask == None: + x_mask = torch.ones(cond.shape[0], target_len).to(cond.device) # (B, T) + if prompt_mask == None: + prompt_mask = torch.ones(cond.shape[0], prompt_len).to( + cond.device + ) # (B, prompt_len) + + cum = torch.zeros(x_mask.shape[0], x_mask.shape[1], self.hidden_size).to( + x_mask.device + ) # (B, T, hidden_size) + + bsz, seq_len, _ = cum.shape + + choice_temp = 1.0 + start_temp = temp # temperature for sampling + start_choice_temp = choice_temp # temperature for choicing mask tokens + + if max_layer is None: + max_layer = self.num_quantizer + + xt = torch.LongTensor(bsz, seq_len, max_layer).to(x_mask.device) + + if gt_code is not None: + gt_layer = gt_code.shape[-1] + xt[:, :, :gt_layer] = gt_code + for i in range(gt_layer): + cum += self.token_emb[i](xt[:, :, i]) + else: + gt_layer = 0 + + for mask_layer in range(gt_layer, max_layer): + steps = n_timesteps[mask_layer] + to_logits = self.to_logits[mask_layer] + token_emb = self.token_emb[mask_layer] + mask_layer = torch.tensor(mask_layer).to(x_mask.device).long().unsqueeze(0) + mask_layer_cond = self.layer_emb(mask_layer).unsqueeze( + 1 + ) # (1,) -> (1, 1, hidden_size) + temp_cond = cond + mask_layer_cond # (B, T, hidden_size) + + mask_token = self.mask_emb(torch.zeros_like(mask_layer)) # (1, hidden_size) + mask = torch.full((bsz, seq_len, 1), True).to(x_mask.device) # (B, T, 1) + seq = torch.full((bsz, seq_len), 0).to(x_mask.device) + + h = 1.0 / steps + + # prompt_code: (B, prompt_len, num_quantizer) + cur_prompt = 0 + for idx, emb in enumerate(self.token_emb): + cur_prompt = cur_prompt + emb( + prompt_code[:, :, idx] + ) # (B, prompt_len, hidden_size) + + t_list = [1.0 - i * h for i in range(steps)] + t_list.append(0.0) + for i in range(steps): + t = t_list[i] * torch.ones(bsz).to(x_mask.device) + token = token_emb(seq) # (B, T, hidden_size) + cur = cum + mask * mask_token[:, None, :] + (~mask) * token + cur = cur + mask_token[:, None, :] * (max_layer - 1 - mask_layer) + + xt_input = torch.cat([cur_prompt, cur], dim=1) # (B, T, hidden_size) + xt_mask = torch.cat( + [prompt_mask, x_mask], dim=1 + ) # (B, T), mask is 0 for padding + + embeds = self.diff_estimator(xt_input, t, temp_cond, xt_mask) + embeds = embeds[:, prompt_len:, :] + + # cfg + if cfg > 0: + mask_embeds = self.diff_estimator( + cur, t, temp_cond[:, prompt_len:, :], x_mask + ) + pos_emb_std = embeds.std() # std(g_cond) + embeds = embeds + cfg * (embeds - mask_embeds) # g_cfg + rescale_embeds = embeds * pos_emb_std / embeds.std() # g_final + embeds = rescale_cfg * rescale_embeds + (1 - rescale_cfg) * embeds + + logits = to_logits(embeds) # (B, T, codebook_size) + annealing_scale = t_list[i] + + choice_temp = start_choice_temp * annealing_scale + temp = start_temp * annealing_scale + logits = top_k(logits, filter_thres) + + if i == steps - 1: + # greedy + if steps == 1: + temp = 0.2 + sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3)) + else: + sampled_ids = logits.argmax(dim=-1) + + else: + # sampling + sampled_ids = gumbel_sample(logits, temperature=max(temp, 1e-3)) + + seq = torch.where(mask.squeeze(-1), sampled_ids, seq) + + scores = logits.softmax(dim=-1) + scores = scores.gather(2, rearrange(sampled_ids, "b n -> b n 1")) + scores = rearrange(scores, "b n 1 -> b n") + + scores = choice_temp * gumbel_noise(scores) + scores + scores = 1 - scores + + next_t = t_list[i + 1] * torch.ones(bsz).to(x_mask.device) + + next_mask_num = (self.mask_prob(next_t) * seq_len).long()[0].item() + + if next_mask_num == 0: + break + scores = scores.masked_fill( + ~mask.squeeze(-1), -torch.finfo(scores.dtype).max + ) + + mask_indices = scores.topk(next_mask_num, dim=-1).indices + mask = torch.zeros_like(scores, dtype=torch.bool).scatter( + 1, mask_indices, True + ) + seq = seq.masked_fill(mask, 0) + + mask = mask.unsqueeze(-1) + + cum = cum + token_emb(seq) + xt[..., mask_layer.squeeze(0).item()] = seq + + return xt + + def forward(self, x0, x_mask, cond_code=None): + # x0: (B, T, num_quantizer) + # x_mask: (B, T) mask is 0 for padding + # cond_code: semantic token (B, T) + cond = self.cond_emb(cond_code) + + logits, mask_layer, final_mask, x0, prompt_len, mask_prob = self.compute_loss( + x0, + x_mask, + cond, + ) + return logits, mask_layer, final_mask, x0, prompt_len, mask_prob diff --git a/indextts/utils/maskgct_utils.py b/indextts/utils/maskgct_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..40b9cb0e159e40ee0cd69ad470af4db62f3f5d5e --- /dev/null +++ b/indextts/utils/maskgct_utils.py @@ -0,0 +1,259 @@ +import torch +import librosa +import json5 +from huggingface_hub import hf_hub_download +from transformers import SeamlessM4TFeatureExtractor, Wav2Vec2BertModel +import safetensors +import numpy as np + +from indextts.utils.maskgct.models.codec.kmeans.repcodec_model import RepCodec +from indextts.utils.maskgct.models.tts.maskgct.maskgct_s2a import MaskGCT_S2A +from indextts.utils.maskgct.models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder +import time + + +def _load_config(config_fn, lowercase=False): + """Load configurations into a dictionary + + Args: + config_fn (str): path to configuration file + lowercase (bool, optional): whether changing keys to lower case. Defaults to False. + + Returns: + dict: dictionary that stores configurations + """ + with open(config_fn, "r") as f: + data = f.read() + config_ = json5.loads(data) + if "base_config" in config_: + # load configurations from new path + p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"]) + p_config_ = _load_config(p_config_path) + config_ = override_config(p_config_, config_) + if lowercase: + # change keys in config_ to lower case + config_ = get_lowercase_keys_config(config_) + return config_ + + +def load_config(config_fn, lowercase=False): + """Load configurations into a dictionary + + Args: + config_fn (str): path to configuration file + lowercase (bool, optional): _description_. Defaults to False. + + Returns: + JsonHParams: an object that stores configurations + """ + config_ = _load_config(config_fn, lowercase=lowercase) + # create an JsonHParams object with configuration dict + cfg = JsonHParams(**config_) + return cfg + + +class JsonHParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = JsonHParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'): + semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0") + semantic_model.eval() + stat_mean_var = torch.load(path_) + semantic_mean = stat_mean_var["mean"] + semantic_std = torch.sqrt(stat_mean_var["var"]) + return semantic_model, semantic_mean, semantic_std + + +def build_semantic_codec(cfg): + semantic_codec = RepCodec(cfg=cfg) + semantic_codec.eval() + return semantic_codec + + +def build_s2a_model(cfg, device): + soundstorm_model = MaskGCT_S2A(cfg=cfg) + soundstorm_model.eval() + soundstorm_model.to(device) + return soundstorm_model + + +def build_acoustic_codec(cfg, device): + codec_encoder = CodecEncoder(cfg=cfg.encoder) + codec_decoder = CodecDecoder(cfg=cfg.decoder) + codec_encoder.eval() + codec_decoder.eval() + codec_encoder.to(device) + codec_decoder.to(device) + return codec_encoder, codec_decoder + + +class Inference_Pipeline(): + def __init__( + self, + semantic_model, + semantic_codec, + semantic_mean, + semantic_std, + codec_encoder, + codec_decoder, + s2a_model_1layer, + s2a_model_full, + ): + self.semantic_model = semantic_model + self.semantic_codec = semantic_codec + self.semantic_mean = semantic_mean + self.semantic_std = semantic_std + + self.codec_encoder = codec_encoder + self.codec_decoder = codec_decoder + self.s2a_model_1layer = s2a_model_1layer + self.s2a_model_full = s2a_model_full + + @torch.no_grad() + def get_emb(self, input_features, attention_mask): + vq_emb = self.semantic_model( + input_features=input_features, + attention_mask=attention_mask, + output_hidden_states=True, + ) + feat = vq_emb.hidden_states[17] # (B, T, C) + feat = (feat - self.semantic_mean.to(feat)) / self.semantic_std.to(feat) + return feat + + @torch.no_grad() + def extract_acoustic_code(self, speech): + vq_emb = self.codec_encoder(speech.unsqueeze(1)) + _, vq, _, _, _ = self.codec_decoder.quantizer(vq_emb) + acoustic_code = vq.permute(1, 2, 0) + return acoustic_code + + @torch.no_grad() + def get_scode(self, inputs): + semantic_code, feat = self.semantic_codec.quantize(inputs) + # vq = self.semantic_codec.quantizer.vq2emb(semantic_code.unsqueeze(1)) + # vq = vq.transpose(1,2) + return semantic_code + + @torch.no_grad() + def semantic2acoustic( + self, + combine_semantic_code, + acoustic_code, + n_timesteps=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + cfg=2.5, + rescale_cfg=0.75, + ): + semantic_code = combine_semantic_code + + cond = self.s2a_model_1layer.cond_emb(semantic_code) + prompt = acoustic_code[:, :, :] + predict_1layer = self.s2a_model_1layer.reverse_diffusion( + cond=cond, + prompt=prompt, + temp=1.5, + filter_thres=0.98, + n_timesteps=n_timesteps[:1], + cfg=cfg, + rescale_cfg=rescale_cfg, + ) + + cond = self.s2a_model_full.cond_emb(semantic_code) + prompt = acoustic_code[:, :, :] + predict_full = self.s2a_model_full.reverse_diffusion( + cond=cond, + prompt=prompt, + temp=1.5, + filter_thres=0.98, + n_timesteps=n_timesteps, + cfg=cfg, + rescale_cfg=rescale_cfg, + gt_code=predict_1layer, + ) + + vq_emb = self.codec_decoder.vq2emb( + predict_full.permute(2, 0, 1), n_quantizers=12 + ) + recovered_audio = self.codec_decoder(vq_emb) + prompt_vq_emb = self.codec_decoder.vq2emb( + prompt.permute(2, 0, 1), n_quantizers=12 + ) + recovered_prompt_audio = self.codec_decoder(prompt_vq_emb) + recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy() + recovered_audio = recovered_audio[0][0].cpu().numpy() + combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio]) + + return combine_audio, recovered_audio + + def s2a_inference( + self, + prompt_speech_path, + combine_semantic_code, + cfg=2.5, + n_timesteps_s2a=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + cfg_s2a=2.5, + rescale_cfg_s2a=0.75, + ): + speech = librosa.load(prompt_speech_path, sr=24000)[0] + acoustic_code = self.extract_acoustic_code( + torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device) + ) + _, recovered_audio = self.semantic2acoustic( + combine_semantic_code, + acoustic_code, + n_timesteps=n_timesteps_s2a, + cfg=cfg_s2a, + rescale_cfg=rescale_cfg_s2a, + ) + + return recovered_audio + + @torch.no_grad() + def gt_inference( + self, + prompt_speech_path, + combine_semantic_code, + ): + speech = librosa.load(prompt_speech_path, sr=24000)[0] + ''' + acoustic_code = self.extract_acoustic_code( + torch.tensor(speech).unsqueeze(0).to(combine_semantic_code.device) + ) + prompt = acoustic_code[:, :, :] + prompt_vq_emb = self.codec_decoder.vq2emb( + prompt.permute(2, 0, 1), n_quantizers=12 + ) + ''' + + prompt_vq_emb = self.codec_encoder(torch.tensor(speech).unsqueeze(0).unsqueeze(1).to(combine_semantic_code.device)) + recovered_prompt_audio = self.codec_decoder(prompt_vq_emb) + recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy() + return recovered_prompt_audio diff --git a/indextts/utils/text_utils.py b/indextts/utils/text_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d8cfc6efafe2a2b79c1c34f8e3c58b182f90a2d7 --- /dev/null +++ b/indextts/utils/text_utils.py @@ -0,0 +1,41 @@ +import re + +from textstat import textstat + + +def contains_chinese(text): + # 正则表达式,用于匹配中文字符 + 数字 -> 都认为是 zh + if re.search(r'[\u4e00-\u9fff0-9]', text): + return True + return False + + +def get_text_syllable_num(text): + chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]') + number_char_pattern = re.compile(r'[0-9]') + syllable_num = 0 + tokens = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z]+|[0-9]+', text) + # print(tokens) + if contains_chinese(text): + for token in tokens: + if chinese_char_pattern.search(token) or number_char_pattern.search(token): + syllable_num += len(token) + else: + syllable_num += textstat.syllable_count(token) + else: + syllable_num = textstat.syllable_count(text) + + return syllable_num + + +def get_text_tts_dur(text): + min_speed = 3 # 2.18 # + max_speed = 5.50 + + ratio = 0.8517 if contains_chinese(text) else 1.0 + + syllable_num = get_text_syllable_num(text) + max_dur = syllable_num * ratio / max_speed + min_dur = syllable_num * ratio / min_speed + + return max_dur, min_dur \ No newline at end of file diff --git a/indextts/utils/typical_sampling.py b/indextts/utils/typical_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..0b225e94d7e328a2bda8ef342956df70576ca4d9 --- /dev/null +++ b/indextts/utils/typical_sampling.py @@ -0,0 +1,30 @@ +import torch +from transformers import TypicalLogitsWarper as BaseTypicalLogitsWarper + +class TypicalLogitsWarper(BaseTypicalLogitsWarper): + def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + super().__init__(mass=mass, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (cumulative_probs < self.mass).sum(dim=1) + last_ind[last_ind < 0] = 0 + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores diff --git a/indextts/utils/utils.py b/indextts/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..82cf8786c201ec9f91e1f0fd6b2684b823e1c780 --- /dev/null +++ b/indextts/utils/utils.py @@ -0,0 +1,93 @@ +import os +import re +import random +import torch +import torchaudio + +MATPLOTLIB_FLAG = False + + +def load_audio(audiopath, sampling_rate): + audio, sr = torchaudio.load(audiopath) + #print(f"wave shape: {audio.shape}, sample_rate: {sr}") + + if audio.size(0) > 1: # mix to mono + audio = audio[0].unsqueeze(0) + + if sr != sampling_rate: + try: + audio = torchaudio.functional.resample(audio, sr, sampling_rate) + except Exception as e: + print(f"Warning: {audiopath}, wave shape: {audio.shape}, sample_rate: {sr}") + return None + # clip audio invalid values + audio.clip_(-1, 1) + return audio + + +def tokenize_by_CJK_char(line: str) -> str: + """ + Tokenize a line of text with CJK char. + + Note: All return charaters will be upper case. + + Example: + input = "你好世界是 hello world 的中文" + output = "你 好 世 界 是 HELLO WORLD 的 中 文" + + Args: + line: + The input text. + + Return: + A new string tokenize by CJK char. + """ + # The CJK ranges is from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py + pattern = re.compile( + r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])" + ) + chars = pattern.split(line.strip().upper()) + return " ".join([w.strip() for w in chars if w.strip()]) + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = torch.arange(0, + max_len, + dtype=torch.int64, + device=lengths.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) diff --git a/indextts/utils/webui_utils.py b/indextts/utils/webui_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5895793cb615a4047a448fe28a7679010628110d --- /dev/null +++ b/indextts/utils/webui_utils.py @@ -0,0 +1,42 @@ +import gradio as gr + + +def html_center(text, label='p'): + return f"""
+ <{label} style="margin: 0; padding: 0;">{text} +
""" + + +def html_left(text, label='p'): + return f"""
+ <{label} style="margin: 0; padding: 0;">{text} +
""" + + +def next_page(page_number,sentences): + new_page_number = int(page_number) + 1 + update_page_number = gr.update(value=str(new_page_number)) + update_prev_page = gr.update(visible=True, interactive=True) + if len(sentences.values) <= new_page_number * 20: + update_next_page = gr.update(visible=False, interactive=False) + else: + update_next_page = gr.update(visible=True, interactive=True) + return update_page_number, update_next_page, update_prev_page + + +def prev_page(page_number): + new_page_number = int(page_number) - 1 + update_page_number = gr.update(value=str(new_page_number)) + if new_page_number == 1: + update_prev_page = gr.update(visible=False, interactive=False) + else: + update_prev_page = gr.update(visible=True, interactive=True) + update_next_page = gr.update(visible=True, interactive=True) + return update_page_number, update_next_page, update_prev_page + + +def update_current_texts(page_number,sentences): + start_index = (int(page_number) - 1) * 20 + end_index = int(page_number) * 20 + current_texts = sentences.values[start_index:end_index if end_index < len(sentences.values) else len(sentences.values)] + return gr.update(values=current_texts) diff --git a/indextts/utils/xtransformers.py b/indextts/utils/xtransformers.py new file mode 100644 index 0000000000000000000000000000000000000000..5470476357f40940477fe4461bbbff6ed65af8c8 --- /dev/null +++ b/indextts/utils/xtransformers.py @@ -0,0 +1,1247 @@ +import math +from collections import namedtuple +from functools import partial +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import einsum, nn + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates', + 'past_key_values', +]) + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def cast_tuple(val, depth): + return val if isinstance(val, tuple) else (val,) * depth + + +class always(): + def __init__(self, val): + self.val = val + + def __call__(self, *args, **kwargs): + return self.val + + +class not_equals(): + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x != self.val + + +class equals(): + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x == self.val + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +# init helpers + +def init_zero_(layer): + nn.init.constant_(layer.weight, 0.) + if exists(layer.bias): + nn.init.constant_(layer.bias, 0.) + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# activations + +class ReluSquared(nn.Module): + def forward(self, x): + return F.relu(x) ** 2 + + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + pos_emb = self.emb(n) + pos_emb = rearrange(pos_emb, 'n d -> () n d') + return pos_emb * self.scale + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return rearrange(emb, 'n d -> () n d') + + +class RelativePositionBias(nn.Module): + def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): + super().__init__() + self.scale = scale + self.causal = causal + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + if not causal: + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, qk_dots): + i, j, device = *qk_dots.shape[-2:], qk_dots.device + q_pos = torch.arange(i, dtype=torch.long, device=device) + k_pos = torch.arange(j, dtype=torch.long, device=device) + rel_pos = k_pos[None, :] - q_pos[:, None] + rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, + max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + bias = rearrange(values, 'i j h -> () h i j') + return qk_dots + (bias * self.scale) + + +class AlibiPositionalBias(nn.Module): + def __init__(self, heads, **kwargs): + super().__init__() + self.heads = heads + slopes = torch.Tensor(self._get_slopes(heads)) + slopes = rearrange(slopes, 'h -> () h () ()') + self.register_buffer('slopes', slopes, persistent=False) + self.register_buffer('bias', None, persistent=False) + + @staticmethod + def _get_slopes(heads): + def get_slopes_power_of_2(n): + start = (2 ** (-2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + if math.log2(heads).is_integer(): + return get_slopes_power_of_2(heads) + + closest_power_of_2 = 2 ** math.floor(math.log2(heads)) + return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ + :heads - closest_power_of_2] + + def forward(self, qk_dots): + h, i, j, device = *qk_dots.shape[-3:], qk_dots.device + + if exists(self.bias) and self.bias.shape[-1] >= j: + return qk_dots + self.bias[..., :j] + + bias = torch.arange(j, device=device) + bias = rearrange(bias, 'j -> () () () j') + bias = bias * self.slopes + + num_heads_unalibied = h - bias.shape[1] + bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) + + self.register_buffer('bias', bias, persistent=False) + return qk_dots + self.bias + + +class LearnedAlibiPositionalBias(AlibiPositionalBias): + def __init__(self, heads, bidirectional=False): + super().__init__(heads) + los_slopes = torch.log(self.slopes) + self.learned_logslopes = nn.Parameter(los_slopes) + + self.bidirectional = bidirectional + if self.bidirectional: + self.learned_logslopes_future = nn.Parameter(los_slopes) + + def forward(self, qk_dots): + h, i, j, device = *qk_dots.shape[-3:], qk_dots.device + + def get_slopes(param): + return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1])) + + if exists(self.bias) and self.bias.shape[-1] >= j: + bias = self.bias[..., :i, :j] + else: + i_arange = torch.arange(i, device=device) + j_arange = torch.arange(j, device=device) + bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1') + self.register_buffer('bias', bias, persistent=False) + + if self.bidirectional: + past_slopes = get_slopes(self.learned_logslopes) + future_slopes = get_slopes(self.learned_logslopes_future) + bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes) + else: + slopes = get_slopes(self.learned_logslopes) + bias = bias * slopes + + return qk_dots + bias + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, max_seq_len, device): + t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq) + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return rearrange(emb, 'n d -> () () n d') + + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + seq_len = t.shape[-2] + freqs = freqs[:, :, -seq_len:] + return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) + + +# norms + +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + out = self.fn(x, **kwargs) + scale_fn = lambda t: t * self.value + + if not isinstance(out, tuple): + return scale_fn(out) + + return (scale_fn(out[0]), *out[1:]) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + out = self.fn(x, **kwargs) + rezero_fn = lambda t: t * self.g + + if not isinstance(out, tuple): + return rezero_fn(out) + + return (rezero_fn(out[0]), *out[1:]) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSScaleShiftNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + self.scale_shift_process = nn.Linear(dim * 2, dim * 2) + + def forward(self, x, norm_scale_shift_inp): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + norm = x / norm.clamp(min=self.eps) * self.g + + ss_emb = self.scale_shift_process(norm_scale_shift_inp) + scale, shift = torch.chunk(ss_emb, 2, dim=1) + h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return h + + +# residual and residual gates + +class Residual(nn.Module): + def __init__(self, dim, scale_residual=False): + super().__init__() + self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + + def forward(self, x, residual): + if exists(self.residual_scale): + residual = residual * self.residual_scale + + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim, scale_residual=False): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + + def forward(self, x, residual): + if exists(self.residual_scale): + residual = residual * self.residual_scale + + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# token shifting + +def shift(t, amount, mask=None): + if amount == 0: + return t + + if exists(mask): + t = t.masked_fill(~mask[..., None], 0.) + + return F.pad(t, (0, 0, amount, -amount), value=0.) + + +class ShiftTokens(nn.Module): + def __init__(self, shifts, fn): + super().__init__() + self.fn = fn + self.shifts = tuple(shifts) + + def forward(self, x, **kwargs): + mask = kwargs.get('mask', None) + shifts = self.shifts + segments = len(shifts) + feats_per_shift = x.shape[-1] // segments + splitted = x.split(feats_per_shift, dim=-1) + segments_to_shift, rest = splitted[:segments], splitted[segments:] + segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))) + x = torch.cat((*segments_to_shift, *rest), dim=-1) + return self.fn(x, **kwargs) + + +# feedforward + +class GLU(nn.Module): + def __init__(self, dim_in, dim_out, activation): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * self.act(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + glu=False, + relu_squared=False, + post_act_ln=False, + dropout=0., + zero_init_output=False + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + activation = ReluSquared() if relu_squared else nn.GELU() + + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + activation + ) if not glu else GLU(dim, inner_dim, activation) + + self.net = nn.Sequential( + project_in, + nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + # init last linear layer to 0 + if zero_init_output: + init_zero_(self.net[-1]) + + def forward(self, x): + return self.net(x) + + +# attention. + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + talking_heads=False, + head_scale=False, + collab_heads=False, + collab_compression=.3, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False, + gate_values=False, + zero_init_output=False, + max_attend_past=None, + qk_norm=False, + scale_init_value=None, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + ): + super().__init__() + self.scale = dim_head ** -0.5 + + self.heads = heads + self.causal = causal + self.max_attend_past = max_attend_past + + qk_dim = v_dim = dim_head * heads + + # collaborative heads + self.collab_heads = collab_heads + if self.collab_heads: + qk_dim = int(collab_compression * qk_dim) + self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim)) + + self.to_q = nn.Linear(dim, qk_dim, bias=False) + self.to_k = nn.Linear(dim, qk_dim, bias=False) + self.to_v = nn.Linear(dim, v_dim, bias=False) + + self.dropout = nn.Dropout(dropout) + + # add GLU gating for aggregated values, from alphafold2 + self.to_v_gate = None + if gate_values: + self.to_v_gate = nn.Linear(dim, v_dim) + nn.init.constant_(self.to_v_gate.weight, 0) + nn.init.constant_(self.to_v_gate.bias, 1) + + # cosine sim attention + self.qk_norm = qk_norm + if qk_norm: + scale_init_value = default(scale_init_value, + -3) # if not provided, initialize as though it were sequence length of 1024 + self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # head scaling + self.head_scale = head_scale + if head_scale: + self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim) + + self.rel_pos_bias = rel_pos_bias + if rel_pos_bias: + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads, + num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance) + + # init output projection 0 + if zero_init_output: + init_zero_(self.to_out) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + attn_mask=None, + sinusoidal_emb=None, + rotary_pos_emb=None, + prev_attn=None, + mem=None, + layer_past=None, + ): + b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists( + context) + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + if not collab_heads: + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + else: + q = einsum('b i d, h d -> b h i d', q, self.collab_mixing) + k = rearrange(k, 'b n d -> b () n d') + v = rearrange(v, 'b n (h d) -> b h n d', h=h) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat([past_key, k], dim=-2) + v = torch.cat([past_value, v], dim=-2) + k_cache = k + v_cache = v + + if exists(rotary_pos_emb) and not has_context: + l = rotary_pos_emb.shape[-1] + (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) + ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) + q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + if collab_heads: + k = k.expand(-1, h, -1, -1) + + if self.qk_norm: + q, k = map(l2norm, (q, k)) + scale = 1 / (self.scale.exp().clamp(min=1e-2)) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots.clone() + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if self.rel_pos_bias: + dots = self.rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if exists(attn_mask): + assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4' + if attn_mask.ndim == 2: + attn_mask = rearrange(attn_mask, 'i j -> () () i j') + elif attn_mask.ndim == 3: + attn_mask = rearrange(attn_mask, 'h i j -> () h i j') + dots.masked_fill_(~attn_mask, mask_value) + + if exists(self.max_attend_past): + i, j = dots.shape[-2:] + range_q = torch.arange(j - i, j, device=device) + range_k = torch.arange(j, device=device) + dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j') + mask = dist > self.max_attend_past + dots.masked_fill_(mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn.clone() + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + + if head_scale: + out = out * self.head_scale_params + + out = rearrange(out, 'b h n d -> b n (h d)') + + if exists(self.to_v_gate): + gates = self.to_v_gate(x) + out = out * gates.sigmoid() + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates, k_cache, v_cache + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rms_scaleshift_norm=False, + use_rmsnorm=False, + use_rezero=False, + alibi_pos_bias=False, + alibi_num_heads=None, + alibi_learned=False, + position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + scale_residual=False, + shift_tokens=0, + sandwich_norm=False, + use_qk_norm_attn=False, + qk_norm_attn_seq_len=None, + zero_init_branch_output=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + self.causal = causal + + rel_pos_bias = 'rel_pos_bias' in attn_kwargs + self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + + rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) + self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None + + assert not ( + alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' + + if alibi_pos_bias: + alibi_num_heads = default(alibi_num_heads, heads) + assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' + alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias + self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) + else: + self.rel_pos = None + + assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' + self.pre_norm = pre_norm + self.sandwich_norm = sandwich_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + self.cross_attend = cross_attend + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + # qk normalization + + if use_qk_norm_attn: + attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists( + qk_norm_attn_seq_len) else None + attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value} + + # zero init + + if zero_init_branch_output: + attn_kwargs = {**attn_kwargs, 'zero_init_output': True} + ff_kwargs = {**ff_kwargs, 'zero_init_output': True} + + # calculate layer block order + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + # calculate token shifting + + shift_tokens = cast_tuple(shift_tokens, len(layer_types)) + + # iterate and construct layers + + for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): + is_last_layer = ind == (len(self.layer_types) - 1) + + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if layer_shift_tokens > 0: + shift_range_upper = layer_shift_tokens + 1 + shift_range_lower = -layer_shift_tokens if not causal else 0 + layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) + + if exists(branch_fn): + layer = branch_fn(layer) + + residual_fn = GRUGating if gate_residual else Residual + residual = residual_fn(dim, scale_residual=scale_residual) + + layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c') + + pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None + post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None + post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None + + norms = nn.ModuleList([ + pre_branch_norm, + post_branch_norm, + post_main_norm + ]) + + self.layers.append(nn.ModuleList([ + norms, + layer, + residual + ])) + + def forward( + self, + x, + context=None, + full_context=None, # for passing a list of hidden states from an encoder + mask=None, + context_mask=None, + attn_mask=None, + mems=None, + return_hiddens=False, + norm_scale_shift_inp=None, + past_key_values=None, + expected_seq_len=None, + ): + + assert not (self.cross_attend ^ (exists(context) or exists( + full_context))), 'context must be passed in if cross_attend is set to True' + assert context is None or full_context is None, 'only one of full_context or context can be provided' + + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + norm_args = {} + if exists(norm_scale_shift_inp): + norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp + + rotary_pos_emb = None + if exists(self.rotary_pos_emb): + if not self.training and self.causal: + assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`" + elif expected_seq_len is None: + expected_seq_len = 0 + seq_len = x.shape[1] + if past_key_values is not None: + seq_len += past_key_values[0][0].shape[-2] + max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]) + rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) + + present_key_values = [] + cross_attn_count = 0 + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + if layer_type == 'a': + layer_mem = mems.pop(0) if mems else None + + residual = x + + pre_branch_norm, post_branch_norm, post_main_norm = norm + + if exists(pre_branch_norm): + x = pre_branch_norm(x, **norm_args) + + if layer_type == 'a' or layer_type == 'c': + if past_key_values is not None: + layer_kv = past_key_values.pop(0) + layer_past = tuple(s.to(x.device) for s in layer_kv) + else: + layer_past = None + + if layer_type == 'a': + out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, + prev_attn, layer_mem, layer_past) + elif layer_type == 'c': + if exists(full_context): + out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None, + None, prev_attn, None, layer_past) + else: + out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past) + elif layer_type == 'f': + out = block(x) + + if layer_type == 'a' or layer_type == 'c' and present_key_values is not None: + present_key_values.append((k.detach(), v.detach())) + + if exists(post_branch_norm): + out = post_branch_norm(out, **norm_args) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if exists(post_main_norm): + x = post_main_norm(x, **norm_args) + + if layer_type == 'c': + cross_attn_count += 1 + + if layer_type == 'f': + hiddens.append(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates, + past_key_values=present_key_values + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + +class Decoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on decoder' + super().__init__(causal=True, **kwargs) + + +class CrossAttender(AttentionLayers): + def __init__(self, **kwargs): + super().__init__(cross_attend=True, only_cross=True, **kwargs) + + +class ViTransformerWrapper(nn.Module): + def __init__( + self, + *, + image_size, + patch_size, + attn_layers, + num_classes=None, + dropout=0., + emb_dropout=0. + ): + super().__init__() + assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' + assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' + dim = attn_layers.dim + num_patches = (image_size // patch_size) ** 2 + patch_dim = 3 * patch_size ** 2 + + self.patch_size = patch_size + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.patch_to_embedding = nn.Linear(patch_dim, dim) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None + + def forward( + self, + img, + return_embeddings=False + ): + p = self.patch_size + + x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) + x = self.patch_to_embedding(x) + b, n, _ = x.shape + + cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + x = self.attn_layers(x) + x = self.norm(x) + + if not exists(self.mlp_head) or return_embeddings: + return x + + return self.mlp_head(x[:, 0]) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + shift_mem_down=0, + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.shift_mem_down = shift_mem_down + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + def init_(self): + nn.init.kaiming_normal_(self.token_emb.weight) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_hiddens=False, + return_attn=False, + mems=None, + use_cache=False, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + if self.shift_mem_down and exists(mems): + mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] + mems = [*mems_r, *mems_l] + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_hiddens: + hiddens = intermediates.hiddens + return out, hiddens + + res = [out] + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] + + +class ContinuousTransformerWrapper(nn.Module): + def __init__( + self, + *, + max_seq_len, + attn_layers, + dim_in=None, + dim_out=None, + emb_dim=None, + emb_dropout=0., + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + + self.max_seq_len = max_seq_len + + self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_attn=False, + mems=None, + use_cache=False, + **kwargs + ): + b, n, _, device = *x.shape, x.device + + x = self.project_in(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + out = self.project_out(x) if not return_embeddings else x + + res = [out] + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] diff --git a/indextts/vqvae/__init__.py b/indextts/vqvae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/indextts/vqvae/xtts_dvae.py b/indextts/vqvae/xtts_dvae.py new file mode 100644 index 0000000000000000000000000000000000000000..0e101b0369d5c53e749b52307947975be68e085f --- /dev/null +++ b/indextts/vqvae/xtts_dvae.py @@ -0,0 +1,395 @@ +import functools +from math import sqrt + +import torch +import torch.distributed as distributed +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from einops import rearrange + + +def default(val, d): + return val if val is not None else d + + +def eval_decorator(fn): + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + + return inner + + +def dvae_wav_to_mel( + wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu") +): + mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=1024, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + norm="slaney", + ).to(device) + wav = wav.to(device) + mel = mel_stft(wav) + mel = torch.log(torch.clamp(mel, min=1e-5)) + if mel_norms is None: + mel_norms = torch.load(mel_norms_file, map_location=device) + mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + self.balancing_heuristic = balancing_heuristic + self.codes = None + self.max_codes = 64000 + self.codes_full = False + self.new_return_order = new_return_order + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, input, return_soft_codes=False): + if self.balancing_heuristic and self.codes_full: + h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes) + mask = torch.logical_or(h > 0.9, h < 0.01).unsqueeze(1) + ep = self.embed.permute(1, 0) + ea = self.embed_avg.permute(1, 0) + rand_embed = torch.randn_like(ep) * mask + self.embed = (ep * ~mask + rand_embed).permute(1, 0) + self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0) + self.cluster_size = self.cluster_size * ~mask.squeeze() + if torch.any(mask): + print(f"Reset {torch.sum(mask)} embedding codes.") + self.codes = None + self.codes_full = False + + flatten = input.reshape(-1, self.dim) + dist = flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + self.embed.pow(2).sum(0, keepdim=True) + soft_codes = -dist + _, embed_ind = soft_codes.max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.balancing_heuristic: + if self.codes is None: + self.codes = embed_ind.flatten() + else: + self.codes = torch.cat([self.codes, embed_ind.flatten()]) + if len(self.codes) > self.max_codes: + self.codes = self.codes[-self.max_codes :] + self.codes_full = True + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + if distributed.is_initialized() and distributed.get_world_size() > 1: + distributed.all_reduce(embed_onehot_sum) + distributed.all_reduce(embed_sum) + + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + diff = (quantize.detach() - input).pow(2).mean() + quantize = input + (quantize - input).detach() + + if return_soft_codes: + return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,)) + elif self.new_return_order: + return quantize, embed_ind, diff + else: + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +# Fits a soft-discretized input to a normal-PDF across the specified dimension. +# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete +# values with the specified expected variance. +class DiscretizationLoss(nn.Module): + def __init__(self, discrete_bins, dim, expected_variance, store_past=0): + super().__init__() + self.discrete_bins = discrete_bins + self.dim = dim + self.dist = torch.distributions.Normal(0, scale=expected_variance) + if store_past > 0: + self.record_past = True + self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device="cpu")) + self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device="cpu")) + self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins)) + else: + self.record_past = False + + def forward(self, x): + other_dims = set(range(len(x.shape))) - set([self.dim]) + averaged = x.sum(dim=tuple(other_dims)) / x.sum() + averaged = averaged - averaged.mean() + + if self.record_past: + acc_count = self.accumulator.shape[0] + avg = averaged.detach().clone() + if self.accumulator_filled > 0: + averaged = torch.mean(self.accumulator, dim=0) * (acc_count - 1) / acc_count + averaged / acc_count + + # Also push averaged into the accumulator. + self.accumulator[self.accumulator_index] = avg + self.accumulator_index += 1 + if self.accumulator_index >= acc_count: + self.accumulator_index *= 0 + if self.accumulator_filled <= 0: + self.accumulator_filled += 1 + + return torch.sum(-self.dist.log_prob(averaged)) + + +class ResBlock(nn.Module): + def __init__(self, chan, conv, activation): + super().__init__() + self.net = nn.Sequential( + conv(chan, chan, 3, padding=1), + activation(), + conv(chan, chan, 3, padding=1), + activation(), + conv(chan, chan, 1), + ) + + def forward(self, x): + return self.net(x) + x + + +class UpsampledConv(nn.Module): + def __init__(self, conv, *args, **kwargs): + super().__init__() + assert "stride" in kwargs.keys() + self.stride = kwargs["stride"] + del kwargs["stride"] + self.conv = conv(*args, **kwargs) + + def forward(self, x): + up = nn.functional.interpolate(x, scale_factor=self.stride, mode="nearest") + return self.conv(up) + + +# DiscreteVAE partially derived from lucidrains DALLE implementation +# Credit: https://github.com/lucidrains/DALLE-pytorch +class DiscreteVAE(nn.Module): + def __init__( + self, + positional_dims=2, + num_tokens=512, + codebook_dim=512, + num_layers=3, + num_resnet_blocks=0, + hidden_dim=64, + channels=3, + stride=2, + kernel_size=4, + use_transposed_convs=True, + encoder_norm=False, + activation="relu", + smooth_l1_loss=False, + straight_through=False, + normalization=None, # ((0.5,) * 3, (0.5,) * 3), + record_codes=False, + discretization_loss_averaging_steps=100, + lr_quantizer_args={}, + ): + super().__init__() + has_resblocks = num_resnet_blocks > 0 + + self.num_tokens = num_tokens + self.num_layers = num_layers + self.straight_through = straight_through + self.positional_dims = positional_dims + self.discrete_loss = DiscretizationLoss( + num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps + ) + + assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now. + if positional_dims == 2: + conv = nn.Conv2d + conv_transpose = nn.ConvTranspose2d + else: + conv = nn.Conv1d + conv_transpose = nn.ConvTranspose1d + if not use_transposed_convs: + conv_transpose = functools.partial(UpsampledConv, conv) + + if activation == "relu": + act = nn.ReLU + elif activation == "silu": + act = nn.SiLU + else: + assert NotImplementedError() + + enc_layers = [] + dec_layers = [] + + if num_layers > 0: + enc_chans = [hidden_dim * 2**i for i in range(num_layers)] + dec_chans = list(reversed(enc_chans)) + + enc_chans = [channels, *enc_chans] + + dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] + dec_chans = [dec_init_chan, *dec_chans] + + enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) + + pad = (kernel_size - 1) // 2 + for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): + enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act())) + if encoder_norm: + enc_layers.append(nn.GroupNorm(8, enc_out)) + dec_layers.append( + nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride=stride, padding=pad), act()) + ) + dec_out_chans = dec_chans[-1] + innermost_dim = dec_chans[0] + else: + enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act())) + dec_out_chans = hidden_dim + innermost_dim = hidden_dim + + for _ in range(num_resnet_blocks): + dec_layers.insert(0, ResBlock(innermost_dim, conv, act)) + enc_layers.append(ResBlock(innermost_dim, conv, act)) + + if num_resnet_blocks > 0: + dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1)) + + enc_layers.append(conv(innermost_dim, codebook_dim, 1)) + dec_layers.append(conv(dec_out_chans, channels, 1)) + + self.encoder = nn.Sequential(*enc_layers) + self.decoder = nn.Sequential(*dec_layers) + + self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss + self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True) + + # take care of normalization within class + self.normalization = normalization + self.record_codes = record_codes + if record_codes: + self.codes = torch.zeros((1228800,), dtype=torch.long) + self.code_ind = 0 + self.total_codes = 0 + self.internal_step = 0 + + def norm(self, images): + if not self.normalization is not None: + return images + + means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization) + arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()" + means, stds = map(lambda t: rearrange(t, arrange), (means, stds)) + images = images.clone() + images.sub_(means).div_(stds) + return images + + def get_debug_values(self, step, __): + if self.record_codes and self.total_codes > 0: + # Report annealing schedule + return {"histogram_codes": self.codes[: self.total_codes]} + else: + return {} + + @torch.no_grad() + @eval_decorator + def get_codebook_indices(self, images): + img = self.norm(images) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, _ = self.codebook(logits) + self.log_codes(codes) + return codes + + def decode(self, img_seq): + self.log_codes(img_seq) + if hasattr(self.codebook, "embed_code"): + image_embeds = self.codebook.embed_code(img_seq) + else: + image_embeds = F.embedding(img_seq, self.codebook.codebook) + b, n, d = image_embeds.shape + + kwargs = {} + if self.positional_dims == 1: + arrange = "b n d -> b d n" + else: + h = w = int(sqrt(n)) + arrange = "b (h w) d -> b d h w" + kwargs = {"h": h, "w": w} + image_embeds = rearrange(image_embeds, arrange, **kwargs) + images = [image_embeds] + for layer in self.decoder: + images.append(layer(images[-1])) + return images[-1], images[-2] + + def infer(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, commitment_loss = self.codebook(logits) + return self.decode(codes) + + # Note: This module is not meant to be run in forward() except while training. It has special logic which performs + # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially + # more lossy (but useful for determining network performance). + def forward(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, commitment_loss = self.codebook(logits) + sampled = sampled.permute((0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1)) + + if self.training: + out = sampled + for d in self.decoder: + out = d(out) + self.log_codes(codes) + else: + # This is non-differentiable, but gives a better idea of how the network is actually performing. + out, _ = self.decode(codes) + + # reconstruction loss + out = out[..., :img.shape[-1]] + recon_loss = self.loss_fn(img, out, reduction="mean") + ssim_loss = torch.zeros(size=(1,)).cuda() + + return recon_loss, ssim_loss, commitment_loss, out + + def log_codes(self, codes): + # This is so we can debug the distribution of codes being learned. + if self.record_codes and self.internal_step % 10 == 0: + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i : i + l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.total_codes += 1 + self.internal_step += 1 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3276b059d0c24590ca47b52d1d598993eaf9e62f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,30 @@ +accelerate==1.8.1 +descript-audiotools==0.7.2 +transformers==4.52.1 +tokenizers==0.21.0 +cn2an==0.5.22 +ffmpeg-python==0.2.0 +Cython==3.0.7 +g2p-en==2.1.0 +jieba==0.42.1 +json5==0.10.0 +keras==2.9.0 +numba==0.58.1 +numpy==1.26.2 +pandas==2.1.3 +matplotlib==3.8.2 +munch==4.0.0 +opencv-python==4.9.0.80 +tensorboard==2.9.1 +librosa==0.10.2.post1 +safetensors==0.5.2 +deepspeed==0.17.1 +modelscope==1.27.0 +omegaconf +sentencepiece +gradio +tqdm +textstat + +WeTextProcessing; platform_machine != "Darwin" +wetext; platform_system == "Darwin" diff --git a/tests/padding_test.py b/tests/padding_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb67d031dd946e0f6a64264f9551f656976b07c --- /dev/null +++ b/tests/padding_test.py @@ -0,0 +1,99 @@ +import torch +import torchaudio +from indextts.infer import IndexTTS +from indextts.utils.feature_extractors import MelSpectrogramFeatures +from torch.nn import functional as F + +if __name__ == "__main__": + """ + Test the padding of text tokens in inference. + ``` + python tests/padding_test.py checkpoints + python tests/padding_test.py IndexTTS-1.5 + ``` + """ + import transformers + transformers.set_seed(42) + import sys + sys.path.append("..") + if len(sys.argv) > 1: + model_dir = sys.argv[1] + else: + model_dir = "checkpoints" + audio_prompt="tests/sample_prompt.wav" + tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False) + text = "晕 XUAN4 是 一 种 not very good GAN3 觉" + text_tokens = tts.tokenizer.encode(text) + text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0) # [1, L] + + audio, sr = torchaudio.load(audio_prompt) + audio = torch.mean(audio, dim=0, keepdim=True) + audio = torchaudio.transforms.Resample(sr, 24000)(audio) + auto_conditioning = MelSpectrogramFeatures()(audio).to(tts.device) + cond_mel_lengths = torch.tensor([auto_conditioning.shape[-1]]).to(tts.device) + with torch.no_grad(): + kwargs = { + "cond_mel_lengths": cond_mel_lengths, + "do_sample": False, + "top_p": 0.8, + "top_k": None, + "temperature": 1.0, + "num_return_sequences": 1, + "length_penalty": 0.0, + "num_beams": 1, + "repetition_penalty": 10.0, + "max_generate_length": 100, + } + # baseline for non-pad + baseline = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs) + baseline = baseline.squeeze(0) + print("Inference padded text tokens...") + pad_text_tokens = [ + F.pad(text_tokens, (8, 0), value=0), # left bos + F.pad(text_tokens, (0, 8), value=1), # right eos + F.pad(F.pad(text_tokens, (4, 0), value=0), (0, 4), value=1), # both side + F.pad(F.pad(text_tokens, (6, 0), value=0), (0, 2), value=1), + F.pad(F.pad(text_tokens, (0, 4), value=0), (0, 4), value=1), + ] + output_for_padded = [] + for t in pad_text_tokens: + # test for each padded text + out = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs) + output_for_padded.append(out.squeeze(0)) + # batched inference + print("Inference padded text tokens as one batch...") + batched_text_tokens = torch.cat(pad_text_tokens, dim=0).to(tts.device) + assert len(pad_text_tokens) == batched_text_tokens.shape[0] and batched_text_tokens.ndim == 2 + batch_output = tts.gpt.inference_speech(auto_conditioning, batched_text_tokens, **kwargs) + del pad_text_tokens + mismatch_idx = [] + print("baseline:", baseline.shape, baseline) + print("--"*10) + print("baseline vs padded output:") + for i in range(len(output_for_padded)): + if not baseline.equal(output_for_padded[i]): + mismatch_idx.append(i) + + if len(mismatch_idx) > 0: + print("mismatch:", mismatch_idx) + for i in mismatch_idx: + print(f"[{i}]: {output_for_padded[i]}") + else: + print("all matched") + + del output_for_padded + print("--"*10) + print("baseline vs batched output:") + mismatch_idx = [] + for i in range(batch_output.shape[0]): + if not baseline.equal(batch_output[i]): + mismatch_idx.append(i) + if len(mismatch_idx) > 0: + print("mismatch:", mismatch_idx) + for i in mismatch_idx: + print(f"[{i}]: {batch_output[i]}") + + else: + print("all matched") + + print("Test finished.") \ No newline at end of file diff --git a/tests/regression_test.py b/tests/regression_test.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf542380ee0bb6b65c928eb5cdb95eaaa9e1371 --- /dev/null +++ b/tests/regression_test.py @@ -0,0 +1,42 @@ +from indextts.infer import IndexTTS + +if __name__ == "__main__": + prompt_wav="tests/sample_prompt.wav" + tts = IndexTTS(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=True, use_cuda_kernel=False) + # 单音频推理测试 + text="晕 XUAN4 是 一 种 GAN3 觉" + tts.infer(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + text='大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!' + tts.infer(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + text="There is a vehicle arriving in dock number 7?" + tts.infer(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + text = "“我爱你!”的英语是“I love you!”" + tts.infer(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + text = "Joseph Gordon-Levitt is an American actor" + tts.infer(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + text = "约瑟夫·高登-莱维特是美国演员" + tts.infer(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + text = "蒂莫西·唐纳德·库克(英文名:Timothy Donald Cook),通称蒂姆·库克(Tim Cook),现任苹果公司首席执行官。" + tts.infer(audio_prompt=prompt_wav, text=text, output_path="outputs/蒂莫西·唐纳德·库克.wav", verbose=True) + # 并行推理测试 + text="亲爱的伙伴们,大家好!每一次的努力都是为了更好的未来,要善于从失败中汲取经验,让我们一起勇敢前行,迈向更加美好的明天!" + tts.infer_fast(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + text="The weather is really nice today, perfect for studying at home.Thank you!" + tts.infer_fast(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + text='''叶远随口答应一声,一定帮忙云云。 +教授看叶远的样子也知道,这事情多半是黄了。 +谁得到这样的东西也不会轻易贡献出来,这是很大的一笔财富。 +叶远回来后,又自己做了几次试验,发现空间湖水对一些外伤也有很大的帮助。 +找来一只断了腿的兔子,喝下空间湖水,一天时间,兔子就完全好了。 +还想多做几次试验,可是身边没有试验的对象,就先放到一边,了解空间湖水可以饮用,而且对人有利,这些就足够了。 +感谢您的收听,下期再见! + '''.replace("\n", "") + tts.infer_fast(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) + # 长文本推理测试 + text = """《盗梦空间》是由美国华纳兄弟影片公司出品的电影,由克里斯托弗·诺兰执导并编剧, +莱昂纳多·迪卡普里奥、玛丽昂·歌迪亚、约瑟夫·高登-莱维特、艾利奥特·佩吉、汤姆·哈迪等联袂主演, +2010年7月16日在美国上映,2010年9月1日在中国内地上映,2020年8月28日在中国内地重映。 +影片剧情游走于梦境与现实之间,被定义为“发生在意识结构内的当代动作科幻片”, +讲述了由莱昂纳多·迪卡普里奥扮演的造梦师,带领特工团队进入他人梦境,从他人的潜意识中盗取机密,并重塑他人梦境的故事。 +""".replace("\n", "") + tts.infer_fast(audio_prompt=prompt_wav, text=text, output_path=f"outputs/{text[:20]}.wav", verbose=True) diff --git a/tools/i18n/i18n.py b/tools/i18n/i18n.py new file mode 100644 index 0000000000000000000000000000000000000000..484e236fff6712a6ec5466ce6530868fd12a545c --- /dev/null +++ b/tools/i18n/i18n.py @@ -0,0 +1,36 @@ +import json +import locale +import os + +I18N_JSON_DIR : os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), 'locale') + +def load_language_list(language): + with open(os.path.join(I18N_JSON_DIR, f"{language}.json"), "r", encoding="utf-8") as f: + language_list = json.load(f) + return language_list + +def scan_language_list(): + language_list = [] + for name in os.listdir(I18N_JSON_DIR): + if name.endswith(".json"):language_list.append(name.split('.')[0]) + return language_list + +class I18nAuto: + def __init__(self, language=None): + if language in ["Auto", None]: + language = locale.getdefaultlocale()[0] + # getlocale can't identify the system's language ((None, None)) + if not os.path.exists(os.path.join(I18N_JSON_DIR, f"{language}.json")): + language = "en_US" + self.language = language + self.language_map = load_language_list(language) + + def __call__(self, key): + return self.language_map.get(key, key) + + def __repr__(self): + return "Use Language: " + self.language + +if __name__ == "__main__": + i18n = I18nAuto(language='en_US') + print(i18n) \ No newline at end of file diff --git a/tools/i18n/locale/en_US.json b/tools/i18n/locale/en_US.json new file mode 100644 index 0000000000000000000000000000000000000000..8ef3dafeb2ad1ed9eae12fbebf3df9ba3898225a --- /dev/null +++ b/tools/i18n/locale/en_US.json @@ -0,0 +1,46 @@ +{ + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "This software is open-sourced under the MIT License. The author has no control over the software, and users of the software, as well as those who distribute the audio generated by the software, assume full responsibility.", + "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "If you do not agree to these terms, you are not permitted to use or reference any code or files within the software package. For further details, please refer to the LICENSE file in the root directory.", + "时长必须为正数": "Duration must be a positive number", + "请输入有效的浮点数": "Please enter a valid floating-point number", + "使用情感参考音频": "Use emotion reference audio", + "使用情感向量控制": "Use emotion vector", + "使用情感描述文本控制": "Use text description to control emotion", + "上传情感参考音频": "Upload emotion reference audio", + "情感权重": "Emotion control weight", + "喜": "Happy", + "怒": "Angry", + "哀": "Sad", + "惧": "Fear", + "厌恶": "Hate", + "低落": "Low", + "惊喜": "Surprise", + "平静": "Neutral", + "情感描述文本": "Emotion description", + "请输入情感描述文本": "Please input emotion description", + "高级生成参数设置": "Advanced generation parameter settings", + "情感向量之和不能超过1.5,请调整后重试。": "The sum of the emotion vectors cannot exceed 1.5. Please adjust and try again.", + "音色参考音频": "Voice reference", + "音频生成": "Speech Synthesis", + "文本": "Text", + "生成语音": "Synthesize", + "生成结果": "Synthesis Result", + "功能设置": "Settings", + "分句设置": "Sentence segmentation settings", + "参数会影响音频质量和生成速度": "Parameters below affect audio quality and generation speed", + "分句最大Token数": "Max tokens per sentence", + "建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高": "Recommended between 80 and 200. The larger the value, the longer the sentences; the smaller the value, the more fragmented the sentences. Values that are too small or too large may lead to poor audio quality.", + "预览分句结果": "Preview sentence segmentation result", + "序号": "Index", + "分句内容": "Content", + "Token数": "Token Count", + "情感控制方式": "Emotion control method", + "GPT2 采样设置": "GPT-2 Sampling Configuration", + "参数会影响音频多样性和生成速度详见": "Influence both the diversity of the generated audio and the generation speed. For further details, refer to", + "请上传情感参考音频": "Please upload emotion reference audio", + "当前模型版本": "Current model version ", + "请输入目标文本": "Please input text to synthesize", + "例如:高兴,愤怒,悲伤等": "e.g., happy, angry, sad, etc.", + "与音色参考音频相同": "Same as the voice reference", + "情感随机采样": "Random emotion sampling" +} \ No newline at end of file diff --git a/tools/i18n/locale/zh_CN.json b/tools/i18n/locale/zh_CN.json new file mode 100644 index 0000000000000000000000000000000000000000..26a75968dcabe4c55477d7ab6c0953638111b5e9 --- /dev/null +++ b/tools/i18n/locale/zh_CN.json @@ -0,0 +1,40 @@ +{ + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.": "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.", + "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.": "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.", + "时长必须为正数": "时长必须为正数", + "请输入有效的浮点数": "请输入有效的浮点数", + "使用情感参考音频": "使用情感参考音频", + "使用情感向量控制": "使用情感向量控制", + "使用情感描述文本控制": "使用情感描述文本控制", + "上传情感参考音频": "上传情感参考音频", + "情感权重": "情感权重", + "喜": "喜", + "怒": "怒", + "哀": "哀", + "惧": "惧", + "厌恶": "厌恶", + "低落": "低落", + "惊喜": "惊喜", + "平静": "平静", + "情感描述文本": "情感描述文本", + "请输入情感描述文本": "请输入情感描述文本", + "高级生成参数设置": "高级生成参数设置", + "情感向量之和不能超过1.5,请调整后重试。": "情感向量之和不能超过1.5,请调整后重试。", + "音色参考音频": "音色参考音频", + "音频生成": "音频生成", + "文本": "文本", + "生成语音": "生成语音", + "生成结果": "生成结果", + "功能设置": "功能设置", + "分句设置": "分句设置", + "参数会影响音频质量和生成速度": "参数会影响音频质量和生成速度", + "分句最大Token数": "分句最大Token数", + "建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高": "建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高", + "预览分句结果": "预览分句结果", + "序号": "序号", + "分句内容": "分句内容", + "Token数": "Token数", + "情感控制方式": "情感控制方式", + "GPT2 采样设置": "GPT2 采样设置", + "参数会影响音频多样性和生成速度详见": "参数会影响音频多样性和生成速度详见" +} \ No newline at end of file diff --git a/tools/i18n/scan_i18n.py b/tools/i18n/scan_i18n.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bd12b7de29491f389cdf9a1f256ea18c20a3a2 --- /dev/null +++ b/tools/i18n/scan_i18n.py @@ -0,0 +1,131 @@ +import ast +import glob +import json +import os +from collections import OrderedDict + +I18N_JSON_DIR : os.PathLike = os.path.join(os.path.dirname(os.path.relpath(__file__)), 'locale') +DEFAULT_LANGUAGE: str = "zh_CN" # 默认语言 +TITLE_LEN : int = 60 # 标题显示长度 +KEY_LEN : int = 30 # 键名显示长度 +SHOW_KEYS : bool = False # 是否显示键信息 +SORT_KEYS : bool = False # 是否按全局键名写入文件 + +def extract_i18n_strings(node): + i18n_strings = [] + + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "i18n" + ): + for arg in node.args: + if isinstance(arg, ast.Str): + i18n_strings.append(arg.s) + + for child_node in ast.iter_child_nodes(node): + i18n_strings.extend(extract_i18n_strings(child_node)) + + return i18n_strings + +def scan_i18n_strings(): + """ + scan the directory for all .py files (recursively) + for each file, parse the code into an AST + for each AST, extract the i18n strings + """ + strings = [] + print(" Scanning Files and Extracting i18n Strings ".center(TITLE_LEN, "=")) + for filename in glob.iglob("**/*.py", recursive=True): + try: + with open(filename, "r", encoding="utf-8") as f: + code = f.read() + if "I18nAuto" in code: + tree = ast.parse(code) + i18n_strings = extract_i18n_strings(tree) + print(f"{filename.ljust(KEY_LEN*3//2)}: {len(i18n_strings)}") + if SHOW_KEYS: + print("\n".join([s for s in i18n_strings])) + strings.extend(i18n_strings) + except Exception as e: + print(f"\033[31m[Failed] Error occur at {filename}: {e}\033[0m") + + code_keys = set(strings) + print(f"{'Total Unique'.ljust(KEY_LEN*3//2)}: {len(code_keys)}") + return code_keys + +def update_i18n_json(json_file, standard_keys): + standard_keys = sorted(standard_keys) + print(f" Process {json_file} ".center(TITLE_LEN, "=")) + # 读取 JSON 文件 + with open(json_file, "r", encoding="utf-8") as f: + json_data = json.load(f, object_pairs_hook=OrderedDict) + # 打印处理前的 JSON 条目数 + len_before = len(json_data) + print(f"{'Total Keys'.ljust(KEY_LEN)}: {len_before}") + # 识别缺失的键并补全 + miss_keys = set(standard_keys) - set(json_data.keys()) + if len(miss_keys) > 0: + print(f"{'Missing Keys (+)'.ljust(KEY_LEN)}: {len(miss_keys)}") + for key in miss_keys: + if DEFAULT_LANGUAGE in json_file: + # 默认语言的键值相同. + json_data[key] = key + else: + # 其他语言的值设置为 #! + 键名以标注未被翻译. + json_data[key] = "#!" + key + if SHOW_KEYS: + print(f"{'Added Missing Key'.ljust(KEY_LEN)}: {key}") + # 识别多余的键并删除 + diff_keys = set(json_data.keys()) - set(standard_keys) + if len(diff_keys) > 0: + print(f"{'Unused Keys (-)'.ljust(KEY_LEN)}: {len(diff_keys)}") + for key in diff_keys: + del json_data[key] + if SHOW_KEYS: + print(f"{'Removed Unused Key'.ljust(KEY_LEN)}: {key}") + # 按键顺序排序 + json_data = OrderedDict( + sorted( + json_data.items(), + key=lambda x: ( + list(standard_keys).index(x[0]) if x[0] in standard_keys and not x[1].startswith('#!') else len(json_data), + ) + ) + ) + # 打印处理后的 JSON 条目数 + if len(miss_keys) != 0 or len(diff_keys) != 0: + print(f"{'Total Keys (After)'.ljust(KEY_LEN)}: {len(json_data)}") + # 识别有待翻译的键 + num_miss_translation = 0 + duplicate_items = {} + for key, value in json_data.items(): + if value.startswith("#!"): + num_miss_translation += 1 + if SHOW_KEYS: + print(f"{'Missing Translation'.ljust(KEY_LEN)}: {key}") + if value in duplicate_items: + duplicate_items[value].append(key) + else: + duplicate_items[value] = [key] + # 打印是否有重复的值 + for value, keys in duplicate_items.items(): + if len(keys) > 1: + print("\n".join([f"\033[31m{'[Failed] Duplicate Value'.ljust(KEY_LEN)}: {key} -> {value}\033[0m" for key in keys])) + + if num_miss_translation > 0: + print(f"\033[31m{'[Failed] Missing Translation'.ljust(KEY_LEN)}: {num_miss_translation}\033[0m") + else: + print(f"\033[32m[Passed] All Keys Translated\033[0m") + # 将处理后的结果写入 JSON 文件 + with open(json_file, "w", encoding="utf-8") as f: + json.dump(json_data, f, ensure_ascii=False, indent=4, sort_keys=SORT_KEYS) + f.write("\n") + print(f" Updated {json_file} ".center(TITLE_LEN, "=") + '\n') + +if __name__ == "__main__": + code_keys = scan_i18n_strings() + for json_file in os.listdir(I18N_JSON_DIR): + if json_file.endswith(r".json"): + json_file = os.path.join(I18N_JSON_DIR, json_file) + update_i18n_json(json_file, code_keys) \ No newline at end of file diff --git a/webui.py b/webui.py new file mode 100644 index 0000000000000000000000000000000000000000..648bb3bf283d5e078d26958b4d714efc3c05d57c --- /dev/null +++ b/webui.py @@ -0,0 +1,341 @@ +import json +import logging +import os +import sys +import threading +import time + +import warnings + +import pandas as pd + +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(current_dir) +sys.path.append(os.path.join(current_dir, "indextts")) + +import argparse +parser = argparse.ArgumentParser(description="IndexTTS WebUI") +parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode") +parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on") +parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the web UI on") +parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory") +parser.add_argument("--is_fp16", action="store_true", default=False, help="Fp16 infer") +cmd_args = parser.parse_args() +from huggingface_hub import snapshot_download +os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache' +snapshot_download(repo_id="IndexTeam/IndexTTS-2", local_dir="./checkpoints") + +if not os.path.exists(cmd_args.model_dir): + print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.") + sys.exit(1) + +for file in [ + "bpe.model", + "gpt.pth", + "config.yaml", + "s2mel.pth", + "wav2vec2bert_stats.pt" +]: + file_path = os.path.join(cmd_args.model_dir, file) + if not os.path.exists(file_path): + print(f"Required file {file_path} does not exist. Please download it.") + sys.exit(1) + +import gradio as gr +from indextts import infer +from indextts.infer_v2 import IndexTTS2 +from tools.i18n.i18n import I18nAuto +from modelscope.hub import api + +i18n = I18nAuto(language="Auto") +MODE = 'local' +tts = IndexTTS2(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),is_fp16=cmd_args.is_fp16) + +# 支持的语言列表 +LANGUAGES = { + "中文": "zh_CN", + "English": "en_US" +} +EMO_CHOICES = [i18n("与音色参考音频相同"), + i18n("使用情感参考音频"), + i18n("使用情感向量控制"), + i18n("使用情感描述文本控制")] +os.makedirs("outputs/tasks",exist_ok=True) +os.makedirs("prompts",exist_ok=True) + +MAX_LENGTH_TO_USE_SPEED = 70 +with open("examples/cases.jsonl", "r", encoding="utf-8") as f: + example_cases = [] + for line in f: + line = line.strip() + if not line: + continue + example = json.loads(line) + if example.get("emo_audio",None): + emo_audio_path = os.path.join("examples",example["emo_audio"]) + else: + emo_audio_path = None + example_cases.append([os.path.join("examples", example.get("prompt_audio", "sample_prompt.wav")), + EMO_CHOICES[example.get("emo_mode",0)], + example.get("text"), + emo_audio_path, + example.get("emo_weight",1.0), + example.get("emo_text",""), + example.get("emo_vec_1",0), + example.get("emo_vec_2",0), + example.get("emo_vec_3",0), + example.get("emo_vec_4",0), + example.get("emo_vec_5",0), + example.get("emo_vec_6",0), + example.get("emo_vec_7",0), + example.get("emo_vec_8",0)] + ) + + +def gen_single(emo_control_method,prompt, text, + emo_ref_path, emo_weight, + vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8, + emo_text,emo_random, + max_text_tokens_per_sentence=120, + *args, progress=gr.Progress()): + output_path = None + if not output_path: + output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav") + # set gradio progress + tts.gr_progress = progress + do_sample, top_p, top_k, temperature, \ + length_penalty, num_beams, repetition_penalty, max_mel_tokens = args + kwargs = { + "do_sample": bool(do_sample), + "top_p": float(top_p), + "top_k": int(top_k) if int(top_k) > 0 else None, + "temperature": float(temperature), + "length_penalty": float(length_penalty), + "num_beams": num_beams, + "repetition_penalty": float(repetition_penalty), + "max_mel_tokens": int(max_mel_tokens), + # "typical_sampling": bool(typical_sampling), + # "typical_mass": float(typical_mass), + } + if type(emo_control_method) is not int: + emo_control_method = emo_control_method.value + if emo_control_method == 0: + emo_ref_path = None + emo_weight = 1.0 + if emo_control_method == 1: + emo_weight = emo_weight + if emo_control_method == 2: + vec = [vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8] + vec_sum = sum([vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8]) + if vec_sum > 1.5: + gr.Warning(i18n("情感向量之和不能超过1.5,请调整后重试。")) + return + else: + vec = None + + print(f"Emo control mode:{emo_control_method},vec:{vec}") + output = tts.infer(spk_audio_prompt=prompt, text=text, + output_path=output_path, + emo_audio_prompt=emo_ref_path, emo_alpha=emo_weight, + emo_vector=vec, + use_emo_text=(emo_control_method==3), emo_text=emo_text,use_random=emo_random, + verbose=cmd_args.verbose, + max_text_tokens_per_sentence=int(max_text_tokens_per_sentence), + **kwargs) + return gr.update(value=output,visible=True) + +def update_prompt_audio(): + update_button = gr.update(interactive=True) + return update_button + +with gr.Blocks(title="IndexTTS Demo") as demo: + mutex = threading.Lock() + gr.HTML(''' +

IndexTTS2: A Breakthrough in Emotionally Expressive and Duration-Controlled Auto-Regressive Zero-Shot Text-to-Speech

+

+ +

+ ''') + with gr.Tab(i18n("音频生成")): + with gr.Row(): + os.makedirs("prompts",exist_ok=True) + prompt_audio = gr.Audio(label=i18n("音色参考音频"),key="prompt_audio", + sources=["upload","microphone"],type="filepath") + prompt_list = os.listdir("prompts") + default = '' + if prompt_list: + default = prompt_list[0] + with gr.Column(): + input_text_single = gr.TextArea(label=i18n("文本"),key="input_text_single", placeholder=i18n("请输入目标文本"), info=f"{i18n('当前模型版本')}{tts.model_version or '1.0'}") + gen_button = gr.Button(i18n("生成语音"), key="gen_button",interactive=True) + output_audio = gr.Audio(label=i18n("生成结果"), visible=True,key="output_audio") + with gr.Accordion(i18n("功能设置")): + # 情感控制选项部分 + with gr.Row(): + emo_control_method = gr.Radio( + choices=EMO_CHOICES, + type="index", + value=EMO_CHOICES[0],label=i18n("情感控制方式")) + # 情感参考音频部分 + with gr.Group(visible=False) as emotion_reference_group: + with gr.Row(): + emo_upload = gr.Audio(label=i18n("上传情感参考音频"), type="filepath") + + with gr.Row(): + emo_weight = gr.Slider(label=i18n("情感权重"), minimum=0.0, maximum=1.6, value=0.8, step=0.01) + + # 情感随机采样 + with gr.Row(): + emo_random = gr.Checkbox(label=i18n("情感随机采样"),value=False,visible=False) + + # 情感向量控制部分 + with gr.Group(visible=False) as emotion_vector_group: + with gr.Row(): + with gr.Column(): + vec1 = gr.Slider(label=i18n("喜"), minimum=0.0, maximum=1.4, value=0.0, step=0.05) + vec2 = gr.Slider(label=i18n("怒"), minimum=0.0, maximum=1.4, value=0.0, step=0.05) + vec3 = gr.Slider(label=i18n("哀"), minimum=0.0, maximum=1.4, value=0.0, step=0.05) + vec4 = gr.Slider(label=i18n("惧"), minimum=0.0, maximum=1.4, value=0.0, step=0.05) + with gr.Column(): + vec5 = gr.Slider(label=i18n("厌恶"), minimum=0.0, maximum=1.4, value=0.0, step=0.05) + vec6 = gr.Slider(label=i18n("低落"), minimum=0.0, maximum=1.4, value=0.0, step=0.05) + vec7 = gr.Slider(label=i18n("惊喜"), minimum=0.0, maximum=1.4, value=0.0, step=0.05) + vec8 = gr.Slider(label=i18n("平静"), minimum=0.0, maximum=1.4, value=0.0, step=0.05) + + with gr.Group(visible=False) as emo_text_group: + with gr.Row(): + emo_text = gr.Textbox(label=i18n("情感描述文本"), placeholder=i18n("请输入情感描述文本"), value="", info=i18n("例如:高兴,愤怒,悲伤等")) + + with gr.Accordion(i18n("高级生成参数设置"), open=False): + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown(f"**{i18n('GPT2 采样设置')}** _{i18n('参数会影响音频多样性和生成速度详见')}[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_") + with gr.Row(): + do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样") + temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=0.8, step=0.1) + with gr.Row(): + top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01) + top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1) + num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1) + with gr.Row(): + repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1) + length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1) + max_mel_tokens = gr.Slider(label="max_mel_tokens", value=1500, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens") + # with gr.Row(): + # typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用") + # typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1) + with gr.Column(scale=2): + gr.Markdown(f'**{i18n("分句设置")}** _{i18n("参数会影响音频质量和生成速度")}_') + with gr.Row(): + max_text_tokens_per_sentence = gr.Slider( + label=i18n("分句最大Token数"), value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence", + info=i18n("建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高"), + ) + with gr.Accordion(i18n("预览分句结果"), open=True) as sentences_settings: + sentences_preview = gr.Dataframe( + headers=[i18n("序号"), i18n("分句内容"), i18n("Token数")], + key="sentences_preview", + wrap=True, + ) + advanced_params = [ + do_sample, top_p, top_k, temperature, + length_penalty, num_beams, repetition_penalty, max_mel_tokens, + # typical_sampling, typical_mass, + ] + + if len(example_cases) > 0: + gr.Examples( + examples=example_cases, + examples_per_page=20, + inputs=[prompt_audio, + emo_control_method, + input_text_single, + emo_upload, + emo_weight, + emo_text, + vec1,vec2,vec3,vec4,vec5,vec6,vec7,vec8] + ) + + def on_input_text_change(text, max_tokens_per_sentence): + if text and len(text) > 0: + text_tokens_list = tts.tokenizer.tokenize(text) + + sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence)) + data = [] + for i, s in enumerate(sentences): + sentence_str = ''.join(s) + tokens_count = len(s) + data.append([i, sentence_str, tokens_count]) + return { + sentences_preview: gr.update(value=data, visible=True, type="array"), + } + else: + df = pd.DataFrame([], columns=[i18n("序号"), i18n("分句内容"), i18n("Token数")]) + return { + sentences_preview: gr.update(value=df), + } + def on_method_select(emo_control_method): + if emo_control_method == 1: + return (gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False) + ) + elif emo_control_method == 2: + return (gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=False) + ) + elif emo_control_method == 3: + return (gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=True) + ) + else: + return (gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False) + ) + + emo_control_method.select(on_method_select, + inputs=[emo_control_method], + outputs=[emotion_reference_group, + emo_random, + emotion_vector_group, + emo_text_group] + ) + + input_text_single.change( + on_input_text_change, + inputs=[input_text_single, max_text_tokens_per_sentence], + outputs=[sentences_preview] + ) + max_text_tokens_per_sentence.change( + on_input_text_change, + inputs=[input_text_single, max_text_tokens_per_sentence], + outputs=[sentences_preview] + ) + prompt_audio.upload(update_prompt_audio, + inputs=[], + outputs=[gen_button]) + + gen_button.click(gen_single, + inputs=[emo_control_method,prompt_audio, input_text_single, emo_upload, emo_weight, + vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8, + emo_text,emo_random, + max_text_tokens_per_sentence, + *advanced_params, + ], + outputs=[output_audio]) + + + +if __name__ == "__main__": + demo.queue(20) + demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)