Spaces:
Running
on
Zero
Running
on
Zero
| # utils/ai_generator.py | |
| import gradio as gr | |
| import os | |
| import time | |
| #from turtle import width # Added for implementing delays | |
| from torch import cuda | |
| import random | |
| from utils.ai_generator_diffusers_flux import generate_ai_image_local | |
| #from pathlib import Path | |
| from huggingface_hub import InferenceClient | |
| import requests | |
| import io | |
| from PIL import Image | |
| from tempfile import NamedTemporaryFile | |
| import utils.constants as constants | |
def generate_image_from_text(text, model_name="flax-community/dalle-mini", image_width=768, image_height=512, progress=gr.Progress(track_tqdm=True)):
    """Generate an image from a text prompt via the Hugging Face InferenceClient.

    Args:
        text: Prompt describing the image to generate.
        model_name: Model identifier passed to the inference client.
        image_width: Width, in pixels, the result is resized to.
        image_height: Height, in pixels, the result is resized to.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        The generated image, resized to ``(image_width, image_height)``.
    """
    client = InferenceClient()
    # InferenceClient instances are not callable; text_to_image() is the
    # task method and it already returns a decoded PIL image, so there is
    # no raw response body to read and parse.
    image = client.text_to_image(text, model=model_name)
    # Resize so the caller always gets the requested dimensions.
    return image.resize((image_width, image_height))
def generate_ai_image(
    map_option,
    prompt_textbox_value,
    neg_prompt_textbox_value,
    model,
    lora_weights=None,
    conditioned_image=None,
    pipeline = "FluxPipeline",
    width=912,
    height=512,
    strength=0.5,
    seed = 0,
    progress=gr.Progress(track_tqdm=True),
    *args,
    **kwargs
):
    """Generate an image locally when a CUDA GPU is present, otherwise remotely.

    A ``seed`` of 0 is replaced by a random value in
    ``[0, constants.MAX_SEED]``. When ``conditioned_image`` is supplied, the
    local path always uses the "FluxImg2ImgPipeline" pipeline regardless of
    the ``pipeline`` argument. The remote path only receives the prompt
    options, model, dimensions and seed; ``lora_weights``,
    ``conditioned_image``, ``pipeline`` and ``strength`` are not forwarded to
    it. ``progress``, ``*args`` and ``**kwargs`` are accepted but unused here.

    Returns:
        Whatever the selected backend (``generate_ai_image_local`` or
        ``generate_ai_image_remote``) returns.
    """
    if seed == 0:
        seed = random.randint(0, constants.MAX_SEED)

    has_local_gpu = cuda.is_available() and cuda.device_count() >= 1
    if not has_local_gpu:
        # No usable CUDA device: delegate to the hosted inference path.
        print("No local GPU available. Sending request to Hugging Face API.")
        return generate_ai_image_remote(
            map_option,
            prompt_textbox_value,
            neg_prompt_textbox_value,
            model,
            height=height,
            width=width,
            seed=seed,
        )

    print("Local GPU available. Generating image locally.")
    # A conditioning image forces the image-to-image pipeline.
    selected_pipeline = "FluxImg2ImgPipeline" if conditioned_image is not None else pipeline
    return generate_ai_image_local(
        map_option,
        prompt_textbox_value,
        neg_prompt_textbox_value,
        model,
        lora_weights=lora_weights,
        seed=seed,
        conditioned_image=conditioned_image,
        pipeline_name=selected_pipeline,
        strength=strength,
        height=height,
        width=width,
    )
def generate_ai_image_remote(map_option, prompt_textbox_value, neg_prompt_textbox_value, model, height=512, width=912, num_inference_steps=30, guidance_scale=3.5, seed=777,progress=gr.Progress(track_tqdm=True)):
    """Generate an image through Hugging Face hosted inference and save it as a PNG.

    When ``map_option`` is not "Prompt", the prompt and negative prompt come
    from ``constants.PROMPTS`` / ``constants.NEGATIVE_PROMPTS``; otherwise the
    textbox values are used. If the ``IS_SHARED_SPACE`` environment variable
    equals "True" the request goes through ``InferenceClient``; otherwise a
    direct POST is sent to the inference API. Each path is attempted up to
    three times with exponential backoff (4s, 8s, ...).

    Args:
        map_option: Key selecting a preset prompt, or "Prompt" for free text.
        prompt_textbox_value: Free-text prompt (used when map_option is "Prompt").
        neg_prompt_textbox_value: Comma-separated negative prompt text.
        model: Model identifier on the Hugging Face Hub.
        height: Requested image height.
        width: Requested image width.
        num_inference_steps: Number of inference steps requested.
        guidance_scale: Guidance scale requested.
        seed: Seed sent with the request.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        Path of the temporary PNG file (also appended to
        ``constants.temp_files``), or ``None`` if generation failed. All
        exceptions are caught, logged with ``print`` and turned into ``None``.
    """
    max_retries = 3
    retry_delay = 4  # Initial delay in seconds; doubled after every failed attempt

    def _split_terms(raw):
        # Turn "a, b, ,c" into ["a", "b", "c"]; empty or None input gives [].
        return [term.strip() for term in raw.split(',') if term.strip()] if raw else []

    try:
        if map_option != "Prompt":
            prompt = constants.PROMPTS[map_option]
            negative_prompt = _split_terms(constants.NEGATIVE_PROMPTS.get(map_option, ""))
        else:
            prompt = prompt_textbox_value
            negative_prompt = _split_terms(neg_prompt_textbox_value)

        # NOTE: negative_prompt is only logged; it is not part of either request.
        print("Remotely Generating image with the following parameters:")
        print(f"Prompt: {prompt}")
        print(f"Negative Prompt: {negative_prompt}")
        print(f"Height: {height}")
        print(f"Width: {width}")
        print(f"Number of Inference Steps: {num_inference_steps}")
        print(f"Guidance Scale: {guidance_scale}")
        print(f"Seed: {seed}")

        use_inference_client = os.getenv("IS_SHARED_SPACE") == "True"
        image = None
        for attempt in range(1, max_retries + 1):
            try:
                if use_inference_client:
                    client = InferenceClient(
                        model,
                        token=constants.HF_API_TOKEN
                    )
                    # text_to_image() takes the prompt positionally and the
                    # generation options as named keywords (not an
                    # inputs/parameters payload). max_sequence_length is not
                    # one of its named options, so it is only sent on the
                    # raw HTTP path below.
                    image = client.text_to_image(
                        prompt,
                        guidance_scale=guidance_scale,
                        num_inference_steps=num_inference_steps,
                        width=width,
                        height=height,
                        seed=seed,
                    )
                    break  # Success: leave the retry loop so for/else is skipped
                else:
                    API_URL = f"https://api-inference.huggingface.co/models/{model}"
                    headers = {
                        "Authorization": f"Bearer {constants.HF_API_TOKEN}",
                        "Content-Type": "application/json"
                    }
                    payload = {
                        "inputs": prompt,
                        "parameters": {
                            "guidance_scale": guidance_scale,
                            "num_inference_steps": num_inference_steps,
                            "width": width,
                            "height": height,
                            "max_sequence_length": 512,
                            "seed": seed
                        }
                    }
                    print(f"Attempt {attempt}: Sending POST request to Hugging Face API...")
                    # Generous timeout (300 seconds): image generation can be slow.
                    response = requests.post(API_URL, headers=headers, json=payload, timeout=300)
                    if response.status_code == 200:
                        image = Image.open(io.BytesIO(response.content))
                        break  # Exit the retry loop on success
                    elif response.status_code == 400:
                        # A malformed request will not succeed on retry.
                        print(f"Bad Request (400): {response.text}")
                        print("Check your request parameters and payload format.")
                        return None
                    elif response.status_code in [429, 504]:
                        print(f"Received status code {response.status_code}. Retrying in {retry_delay} seconds...")
                        if attempt < max_retries:
                            time.sleep(retry_delay)
                            retry_delay *= 2  # Exponential backoff
                        else:
                            response.raise_for_status()  # Raise after max retries
                    else:
                        print(f"Received unexpected status code {response.status_code}: {response.text}")
                        response.raise_for_status()
            except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as timeout_error:
                print(f"Timeout occurred: {timeout_error}. Retrying in {retry_delay} seconds...")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    raise  # Handled by the outer except, which returns None
            except requests.exceptions.RequestException as req_error:
                print(f"Request exception: {req_error}. Retrying in {retry_delay} seconds...")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    raise  # Handled by the outer except, which returns None
        else:
            # Loop finished without a break: no attempt produced an image.
            print("Max retries exceeded. Failed to generate image.")
            return None

        # delete=False keeps the file on disk after the handle closes; the
        # path is recorded in constants.temp_files.
        with NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            # Write through the open handle rather than re-opening by name.
            image.save(tmp, format="PNG")
        constants.temp_files.append(tmp.name)
        print(f"Image saved to {tmp.name}")
        return tmp.name
    except Exception as e:
        print(f"Error generating AI image: {e}")
        return None