sumityadav329 committed on
Commit ada9c6c · verified · 1 Parent(s): f12470b

utils.py created

Files changed (1)
  1. utils.py +91 -0
utils.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import requests
+ import time
+ from typing import Optional
+
+ def load_environment():
+     """
+     Attempt to load environment variables with error handling.
+
+     Returns:
+         Optional[str]: Hugging Face token or None
+     """
+     try:
+         from dotenv import load_dotenv
+         load_dotenv()
+     except ImportError:
+         print("python-dotenv not installed. Ensure HF_TOKEN is set in the environment.")
+
+     return os.getenv("HF_TOKEN")
+
+ def query_hf_api(
+     prompt: str,
+     model_url: str = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0",
+     max_retries: int = 3
+ ) -> Optional[bytes]:
+     """
+     Query the Hugging Face Inference API with robust error handling and a retry mechanism.
+
+     Args:
+         prompt (str): Text prompt for image generation
+         model_url (str): URL of the Hugging Face model
+         max_retries (int): Maximum number of retry attempts
+
+     Returns:
+         Optional[bytes]: Generated image bytes or None
+     """
+     # Validate inputs
+     if not prompt or not prompt.strip():
+         raise ValueError("Prompt cannot be empty")
+
+     # Load token
+     HF_TOKEN = load_environment()
+     if not HF_TOKEN:
+         raise ValueError("Hugging Face token not found. Set HF_TOKEN in .env or environment variables.")
+
+     # Prepare headers
+     headers = {
+         "Authorization": f"Bearer {HF_TOKEN}",
+         "Content-Type": "application/json"
+     }
+
+     # Payload with additional configuration
+     payload = {
+         "inputs": prompt,
+         "parameters": {
+             "negative_prompt": "low quality, bad anatomy, blurry",
+             "num_inference_steps": 50,
+         }
+     }
+
+     # Retry mechanism
+     for attempt in range(max_retries):
+         try:
+             response = requests.post(
+                 model_url,
+                 headers=headers,
+                 json=payload,
+                 timeout=120  # 2-minute timeout
+             )
+
+             # Check for specific error conditions
+             if response.status_code == 503:
+                 # Model might be loading, wait and retry
+                 print(f"Service unavailable, retrying in {5 * (attempt + 1)} seconds...")
+                 time.sleep(5 * (attempt + 1))
+                 continue
+
+             response.raise_for_status()  # Raise exception for bad status codes
+
+             return response.content
+
+         except requests.exceptions.RequestException as e:
+             print(f"Request error (Attempt {attempt + 1}/{max_retries}): {e}")
+
+             if attempt == max_retries - 1:
+                 raise RuntimeError(f"Failed to generate image after {max_retries} attempts: {e}")
+
+             # Wait before retrying
+             time.sleep(5 * (attempt + 1))
+
+     raise RuntimeError("Unexpected error in image generation")
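As a quick orientation for reviewers, here is a minimal usage sketch of query_hf_api. The calling app file is not part of this commit, so the prompt text and output path below are assumptions for illustration only:

# Hypothetical caller (not in this commit): generate an image and save the raw bytes.
from utils import query_hf_api

if __name__ == "__main__":
    # Any prompt works; this one is just an example.
    image_bytes = query_hf_api("a watercolor painting of a lighthouse at dawn")
    if image_bytes:
        # The Inference API returns the encoded image directly in the response body.
        with open("output.png", "wb") as f:
            f.write(image_bytes)
        print("Image saved to output.png")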