{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Auffusion text-to-audio inference\n",
    "\n",
    "Loads the pretrained `auffusion/auffusion` pipeline and generates an audio clip from a text prompt twice: once with the pipeline's default sampling settings, and once with an explicit step count, guidance scale and seeded generator so the run can be repeated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Choose the physical GPU; this must be set before torch initialises CUDA, hence before `import torch`.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "import sys\n",
    "\n",
    "# Make `auffusion_pipeline` importable from the parent directory (assumes the notebook is run from a repo subfolder).\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import torch\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "from auffusion_pipeline import AuffusionPipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and runtime configuration.\n",
    "PRETRAINED_MODEL_NAME_OR_PATH = \"auffusion/auffusion\"\n",
    "DTYPE = torch.float16  # half precision for the whole pipeline\n",
    "DEVICE = \"cuda\"\n",
    "# Playback rate for generated clips; assumed to match the pipeline's output rate (TODO confirm against the vocoder config).\n",
    "SAMPLE_RATE = 16000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the pipeline\n",
    "\n",
    "Loads the pretrained weights and moves the pipeline to `DEVICE` in `DTYPE`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = AuffusionPipeline.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)\n",
    "pipeline = pipeline.to(DEVICE, DTYPE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate with default settings\n",
    "\n",
    "No seeded generator is passed here, so the resulting clip is not expected to be reproducible across runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A kitten mewing for attention\"\n",
    "# Alternative prompt: \"Birds singing sweetly in a blooming garden\"\n",
    "\n",
    "output = pipeline(prompt=prompt)\n",
    "display(Audio(output.audios[0], rate=SAMPLE_RATE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reproducible generation\n",
    "\n",
    "Same prompt, but with an explicit number of inference steps, a guidance scale and a seeded `torch.Generator`, so re-running the cell starts from the same random state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"A kitten mewing for attention\"\n",
    "# Alternative prompt: \"Birds singing sweetly in a blooming garden\"\n",
    "num_inference_steps = 100\n",
    "guidance_scale = 7.5\n",
    "seed = 42\n",
    "\n",
    "# Seeded generator on the pipeline's device so sampling is repeatable.\n",
    "generator = torch.Generator(device=DEVICE).manual_seed(seed)\n",
    "\n",
    "# autocast expects a device *type*; torch.device(...).type also handles values such as \"cuda:0\".\n",
    "with torch.autocast(torch.device(DEVICE).type):\n",
    "    output = pipeline(\n",
    "        prompt=prompt,\n",
    "        num_inference_steps=num_inference_steps,\n",
    "        guidance_scale=guidance_scale,\n",
    "        generator=generator,\n",
    "    )\n",
    "\n",
    "display(Audio(output.audios[0], rate=SAMPLE_RATE))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "TTA",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}