{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"../\")\n",
"\n",
"import torch\n",
"from auffusion_pipeline import AuffusionPipeline\n",
"from IPython.display import display, Audio"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"pretrained_model_name_or_path = \"auffusion/auffusion\"\n",
"dtype = torch.float16\n",
"device = \"cuda\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The config attributes {'decay': 0.9999, 'inv_gamma': 1.0, 'min_decay': 0.0, 'optimization_step': 100000, 'power': 0.6666666666666666, 'update_after_step': 0, 'use_ema_warmup': False} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Removing weight norm...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"LOADING CONDITION ENCODER 0\n",
"INITIATED: ConditionAdapter: {'text_encoder_name': 'text_encoder_0', 'condition_adapter_name': 'condition_adapter_0', 'condition_type': 'clip-vit-large-patch14_text', 'pretrained_model_name_or_path': 'openai/clip-vit-large-patch14', 'condition_max_length': 77, 'condition_dim': 768, 'cross_attention_dim': 768}\n",
"LOADED: ConditionAdapter from /home/xjl/Project/Audio/T2A/Auffusion/huggingface_checkpoint/auffusion/condition_adapter_0\n",
"LOADING CONDITION ADAPTER 0\n"
]
}
],
"source": [
"pipeline = AuffusionPipeline.from_pretrained(pretrained_model_name_or_path)\n",
"pipeline = pipeline.to(device, dtype)"
]
},
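  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional memory-saving sketch (an assumption, not part of the demonstrated API):\n",
    "# the load log above shows a diffusers UNet2DConditionModel, so if AuffusionPipeline\n",
    "# also exposes the diffusers DiffusionPipeline interface, attention slicing can lower\n",
    "# peak GPU memory at some speed cost. The hasattr guard makes this a no-op otherwise.\n",
    "if hasattr(pipeline, \"enable_attention_slicing\"):\n",
    "    pipeline.enable_attention_slicing()"
   ]
  },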
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2696f2d9449c4d939760561121333b07",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/50 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prompt = \"Birds singing sweetly in a blooming garden\"\n",
"prompt = \"A kitten mewing for attention\"\n",
"\n",
"output = pipeline(prompt=prompt)\n",
"audio = output.audios[0]\n",
"display(Audio(audio, rate=16000))"
]
},
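  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of saving the result to disk, assuming output.audios[0] is a 1-D\n",
    "# float waveform at the 16 kHz rate used for playback above. The soundfile package\n",
    "# and the output filename are assumptions, not part of the Auffusion API.\n",
    "import soundfile as sf\n",
    "\n",
    "sf.write(\"auffusion_sample.wav\", audio, samplerate=16000)"
   ]
  },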
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "92ccaa24df8c4d138021301728a72ef8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/100 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"prompt = \"Birds singing sweetly in a blooming garden\"\n",
"prompt = \"A kitten mewing for attention\"\n",
"num_inference_steps = 100\n",
"guidance_scale = 7.5\n",
"seed = 42\n",
"\n",
"\n",
"generator = torch.Generator(device=device).manual_seed(seed)\n",
"\n",
"with torch.autocast(\"cuda\"):\n",
" output = pipeline(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator) \n",
"\n",
"audio = output.audios[0]\n",
"display(Audio(audio, rate=16000))"
]
},
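  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch: compare a few seeds for the same prompt, reusing only the calls\n",
    "# shown above (pipeline(...) with a torch.Generator, and IPython's Audio display).\n",
    "# The specific seed values are arbitrary.\n",
    "for s in [0, 42, 1234]:\n",
    "    gen = torch.Generator(device=device).manual_seed(s)\n",
    "    with torch.autocast(\"cuda\"):\n",
    "        out = pipeline(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=gen)\n",
    "    print(f\"seed = {s}\")\n",
    "    display(Audio(out.audios[0], rate=16000))"
   ]
  },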
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "TTA",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}