	Testing Seed Values
Allow loading from file
- app.py +19 -7
- audiocraft/models/loaders.py +4 -0
- audiocraft/models/musicgen.py +5 -4
- audiocraft/utils/extend.py +22 -61
    	
app.py CHANGED

@@ -15,6 +15,7 @@ from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
 from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, sanitize_file_name
 import numpy as np
+import random
 
 MODEL = None
 IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
@@ -25,7 +26,7 @@ def load_model(version):
     return MusicGen.get_pretrained(version)
 
 
-def predict(model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color):
+def predict(model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color, seed, overlap=1):
     global MODEL
     output_segments = None
     topk = int(topk)
@@ -36,6 +37,10 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
         segment_duration = MODEL.lm.cfg.dataset.segment_duration
     else:
         segment_duration = duration
+    # implement seed
+    if seed < 0:
+        seed = random.randint(0, 0xffff_ffff_ffff)
+    torch.manual_seed(seed)
     MODEL.set_generation_params(
         use_sampling=True,
         top_k=topk,
@@ -47,7 +52,7 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
 
     if melody:
         if duration > MODEL.lm.cfg.dataset.segment_duration:
-            output_segments = generate_music_segments(text, melody, MODEL, duration, MODEL.lm.cfg.dataset.segment_duration)
+            output_segments = generate_music_segments(text, melody, MODEL, seed, duration, overlap, MODEL.lm.cfg.dataset.segment_duration)
        else:
            # pure original code
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
@@ -76,14 +81,13 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
         output = output.detach().cpu().float()[0]
     with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
         if include_settings:
-            video_description = f"{text}\n Duration: {str(duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef}"
+            video_description = f"{text}\n Duration: {str(duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef} overlap: {overlap}\n Seed: {seed}"
             background = add_settings_to_image(title, video_description, background_path=background, font=settings_font, font_color=settings_font_color)
-        #filename = sanitize_file_name(title) if title != "" else file.name
         audio_write(
             file.name, output, MODEL.sample_rate, strategy="loudness",
             loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
         waveform_video = gr.make_waveform(file.name,bg_image=background, bar_count=40)
-    return waveform_video
+    return waveform_video, seed
 
 
 def ui(**kwargs):
@@ -121,15 +125,23 @@ def ui(**kwargs):
                     model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
                 with gr.Row():
                     duration = gr.Slider(minimum=1, maximum=1000, value=10, label="Duration", interactive=True)
+                    overlap = gr.Slider(minimum=1, maximum=29, value=5, step=1, label="Overlap", interactive=True)
                     dimension = gr.Slider(minimum=-2, maximum=1, value=1, step=1, label="Dimension", info="determines which direction to add new segements of audio. (0 = stack tracks, 1 = lengthen, -1 = ?)", interactive=True)
                 with gr.Row():
                     topk = gr.Number(label="Top-k", value=250, interactive=True)
                     topp = gr.Number(label="Top-p", value=0, interactive=True)
                     temperature = gr.Number(label="Randomness Temperature", value=1.0, precision=2, interactive=True)
                     cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, precision=2, interactive=True)
-
+                with gr.Row():
+                    seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
+                    gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
+                    reuse_seed = gr.Button('\u267b\ufe0f').style(full_width=False)
+            with gr.Column() as c:
                 output = gr.Video(label="Generated Music")
-
+                seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
+
+        reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
+        submit.click(predict, inputs=[model, text, melody, duration, dimension, topk, topp, temperature, cfg_coef, background, title, include_settings, settings_font, settings_font_color, seed, overlap], outputs=[output, seed_used])
        gr.Examples(
            fn=predict,
            examples=[
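For context on the new seed plumbing: a negative seed means "pick one at random", the resolved value is passed to torch.manual_seed before generation, and it is returned next to the video so the UI can show it in "Seed used" and feed it back in via the reuse button. Below is a minimal sketch of that pattern outside the Gradio app; only the seeding logic mirrors the commit, and the generate stub is a hypothetical stand-in for the model call.

import random
import torch

def resolve_and_apply_seed(seed: int) -> int:
    # Negative seed requests a fresh random one, as in predict().
    if seed < 0:
        seed = random.randint(0, 0xffff_ffff_ffff)
    torch.manual_seed(seed)
    return seed

def generate(seed: int = -1):
    used_seed = resolve_and_apply_seed(seed)
    # ... run the model here; results are reproducible for a fixed used_seed ...
    return torch.rand(4), used_seed  # placeholder output instead of real audio

audio, used_seed = generate(-1)           # random seed chosen and reported back
audio_again, _ = generate(used_seed)      # same seed -> same tensor
assert torch.equal(audio, audio_again)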
    	
audiocraft/models/loaders.py CHANGED

@@ -50,6 +50,10 @@ def _get_state_dict(
 
     if os.path.isfile(file_or_url_or_id):
         return torch.load(file_or_url_or_id, map_location=device)
+
+    if os.path.isdir(file_or_url_or_id):
+        file = f"{file_or_url_or_id}/{filename}"
+        return torch.load(file, map_location=device)
 
     elif file_or_url_or_id.startswith('https://'):
         return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
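The added branch resolves a directory plus the requested filename to a concrete path and hands it to torch.load, alongside the existing file and URL branches. A small self-contained check of that behaviour follows; the checkpoint filename here is illustrative, audiocraft's loaders pass their own filenames into _get_state_dict.

import os
import tempfile
import torch

def load_from_dir(file_or_url_or_id: str, filename: str, device: str = "cpu"):
    # Mirrors the new directory branch: join the directory and filename, then torch.load.
    if os.path.isdir(file_or_url_or_id):
        return torch.load(f"{file_or_url_or_id}/{filename}", map_location=device)
    raise FileNotFoundError(file_or_url_or_id)

with tempfile.TemporaryDirectory() as ckpt_dir:
    torch.save({"weight": torch.zeros(2, 2)}, os.path.join(ckpt_dir, "state_dict.bin"))
    state = load_from_dir(ckpt_dir, "state_dict.bin")
    print(sorted(state))  # ['weight']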
    	
audiocraft/models/musicgen.py CHANGED

@@ -80,10 +80,11 @@ class MusicGen:
             return MusicGen(name, compression_model, lm)
 
         if name not in HF_MODEL_CHECKPOINTS_MAP:
-            raise ValueError(
-                f"{name} is not a valid checkpoint name. "
-                f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
-            )
+            if not os.path.isfile(name) and not os.path.isdir(name):
+                raise ValueError(
+                    f"{name} is not a valid checkpoint name. "
+                    f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
+                )
 
         cache_dir = os.environ.get('MUSICGEN_ROOT', None)
         compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
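With the relaxed check, get_pretrained only raises for names that are neither a known checkpoint nor an existing local file or directory. A hedged usage sketch; the local path is purely illustrative, and that directory still has to contain whatever checkpoint files the loaders look for.

from audiocraft.models import MusicGen

# Known names from HF_MODEL_CHECKPOINTS_MAP behave exactly as before.
model = MusicGen.get_pretrained("melody")

# New in this commit: a path to a local checkpoint file or directory is no longer rejected.
# "/data/checkpoints/my_musicgen" is a hypothetical example path.
local_model = MusicGen.get_pretrained("/data/checkpoints/my_musicgen")

# Anything else still raises ValueError("... is not a valid checkpoint name ...").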
    	
audiocraft/utils/extend.py CHANGED

@@ -8,29 +8,34 @@ import tempfile
 import os
 import textwrap
 
-def separate_audio_segments(audio, segment_duration=30):
+def separate_audio_segments(audio, segment_duration=30, overlap=1):
     sr, audio_data = audio[0], audio[1]
-
+
     total_samples = len(audio_data)
     segment_samples = sr * segment_duration
-
-
-
+    overlap_samples = sr * overlap
+
     segments = []
-
-
-
-
-        end_sample = (segment_idx + 1) * segment_samples
-
+    start_sample = 0
+
+    while total_samples >= segment_samples:
+        end_sample = start_sample + segment_samples
         segment = audio_data[start_sample:end_sample]
         segments.append((sr, segment))
-
+
+        start_sample += segment_samples - overlap_samples
+        total_samples -= segment_samples - overlap_samples
+
+    # Collect the final segment
+    if total_samples > 0:
+        segment = audio_data[-segment_samples:]
+        segments.append((sr, segment))
+
     return segments
 
-def generate_music_segments(text, melody, MODEL, duration:int=10, segment_duration:int=30):
+def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:int=1, segment_duration:int=30):
     # generate audio segments
-    melody_segments = separate_audio_segments(melody, segment_duration)
+    melody_segments = separate_audio_segments(melody, segment_duration, overlap)
 
     # Create a list to store the melody tensors for each segment
     melodys = []
@@ -40,7 +45,7 @@ def generate_music_segments(text, melody, MODEL, duration:int=10, segment_durati
     total_segments = max(math.ceil(duration / segment_duration),1)
     print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds")
 
-    # If melody_segments is shorter than total_segments, repeat the segments until the 
+    # If melody_segments is shorter than total_segments, repeat the segments until the total_segments is reached
     if len(melody_segments) < total_segments:
         for i in range(total_segments - len(melody_segments)):
             segment = melody_segments[i]
@@ -59,6 +64,7 @@ def generate_music_segments(text, melody, MODEL, duration:int=10, segment_durati
         # Append the segment to the melodys list
         melodys.append(verse)
 
+    torch.manual_seed(seed)
     for idx, verse in enumerate(melodys):
         print(f"Generating New Melody Segment {idx + 1}: {text}\r")
         output = MODEL.generate_with_chroma(
@@ -74,42 +80,6 @@ def generate_music_segments(text, melody, MODEL, duration:int=10, segment_durati
         print(f"output_segments: {len(output_segments)}: shape: {output.shape} dim {output.dim()}")
     return output_segments
 
-#def generate_music_segments(text, melody, duration, MODEL, segment_duration=30):
-#    sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
-
-#    # Create a list to store the melody tensors for each segment
-#    melodys = []
-
-#    # Calculate the total number of segments
-#    total_segments = math.ceil(melody.shape[1] / (sr * segment_duration))
-
-#    # Iterate over the segments
-#    for segment_idx in range(total_segments):
-#        print(f"segment {segment_idx + 1} / {total_segments + 1} \r")
-#        start_frame = segment_idx * sr * segment_duration
-#        end_frame = (segment_idx + 1) * sr * segment_duration
-
-#        # Extract the segment from the melody tensor
-#        segment = melody[:, start_frame:end_frame]
-
-#        # Append the segment to the melodys list
-#        melodys.append(segment)
-
-#    output_segments = []
-
-#    for segment in melodys:
-#        output = MODEL.generate_with_chroma(
-#            descriptions=[text],
-#            melody_wavs=segment,
-#            melody_sample_rate=sr,
-#            progress=False
-#        )
-
-#        # Append the generated output to the list of segments
-#        output_segments.append(output[:, :segment_duration])
-
-#    return output_segments
-
 def save_image(image):
     """
     Saves a PIL image to a temporary file and returns the file path.
@@ -184,13 +154,4 @@ def add_settings_to_image(title: str = "title", description: str = "", width: in
     background.paste(image, offset, mask=image)
 
     # Save the image and return the file path
-    return save_image(background)
-
-
-def sanitize_file_name(filename):
-    valid_chars = "-_.() " + string.ascii_letters + string.digits
-    sanitized_filename = ''.join(c for c in filename if c in valid_chars)
-    return sanitized_filename
-
-
-
+    return save_image(background)
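To make the overlap arithmetic concrete: the rewritten separate_audio_segments slides a window of segment_duration seconds across the waveform and advances it by segment_duration - overlap seconds, so consecutive segments share overlap seconds of audio, and a final full-length window covers whatever tail remains. Below is a self-contained copy of that windowing logic plus a toy trace; the tiny sample rate and durations are stand-ins chosen only to make the numbers easy to follow.

import numpy as np

def separate_audio_segments(audio, segment_duration=30, overlap=1):
    # Same windowing as the updated helper, operating on a (sample_rate, samples) tuple.
    sr, audio_data = audio[0], audio[1]
    total_samples = len(audio_data)
    segment_samples = sr * segment_duration
    overlap_samples = sr * overlap

    segments = []
    start_sample = 0
    while total_samples >= segment_samples:
        end_sample = start_sample + segment_samples
        segments.append((sr, audio_data[start_sample:end_sample]))
        start_sample += segment_samples - overlap_samples
        total_samples -= segment_samples - overlap_samples
    if total_samples > 0:
        # Final segment: a full-length window ending at the last sample.
        segments.append((sr, audio_data[-segment_samples:]))
    return segments

# Toy trace: sr=4, 22 "seconds" of audio, 8-second windows, 2 seconds of overlap.
sr, audio = 4, np.arange(88)
windows = [(int(seg[0]), int(seg[-1])) for _, seg in separate_audio_segments((sr, audio), 8, 2)]
print(windows)  # [(0, 31), (24, 55), (48, 79), (56, 87)] -> window starts 24 samples (6 s) apart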
