jiuface commited on
Commit
91d2029
1 Parent(s): ea8efbb

upload img to s3

Browse files
Files changed (2) hide show
  1. app.py +41 -3
  2. requirements.txt +2 -1
app.py CHANGED
@@ -17,6 +17,8 @@ from transformers import DPTFeatureExtractor, DPTForDepthEstimation, DPTImagePro
17
  from transformers import CLIPImageProcessor
18
  from diffusers.utils import load_image
19
  from gradio_imageslider import ImageSlider
 
 
20
 
21
  device = "cuda"
22
  base_model_id = "SG161222/RealVisXL_V4.0"
@@ -82,9 +84,25 @@ def get_depth_map(image):
82
  return image
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  @spaces.GPU(enable_queue=True)
87
- def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed, progress=gr.Progress()):
88
 
89
  if image_url:
90
  orginal_image = load_image(image_url)
@@ -106,7 +124,14 @@ def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, contr
106
  generator=generator,
107
  image=depth_image
108
  ).images[0]
109
- return [[depth_image, generated_image], "ok"]
 
 
 
 
 
 
 
110
 
111
  with gr.Blocks() as demo:
112
 
@@ -128,6 +153,14 @@ with gr.Blocks() as demo:
128
  label="Negative prompt",
129
  value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
130
  )
 
 
 
 
 
 
 
 
131
  with gr.Column():
132
  result = ImageSlider(label="Generate image", type="pil", slider_color="pink")
133
  logs = gr.Textbox(label="logs")
@@ -140,7 +173,12 @@ with gr.Blocks() as demo:
140
  num_steps,
141
  guidance_scale,
142
  control_strength,
143
- seed
 
 
 
 
 
144
  ]
145
  run_button.click(
146
  fn=randomize_seed_fn,
 
17
  from transformers import CLIPImageProcessor
18
  from diffusers.utils import load_image
19
  from gradio_imageslider import ImageSlider
20
+ import boto3
21
+ from io import BytesIO
22
 
23
  device = "cuda"
24
  base_model_id = "SG161222/RealVisXL_V4.0"
 
84
  return image
85
 
86
 
87
+ def upload_to_s3(image, region, access_key, secret_key, bucket_name):
88
+ s3 = boto3.client(
89
+ 's3',
90
+ region_name=region,
91
+ aws_access_key_id=access_key,
92
+ aws_secret_access_key=secret_key
93
+ )
94
+ image_key = f"generated_images/{random.randint(0, MAX_SEED)}.png"
95
+ buffer = BytesIO()
96
+ image.save(buffer, "PNG")
97
+ buffer.seek(0)
98
+
99
+ s3.upload_fileobj(buffer, bucket_name, image_key)
100
+ return image_key
101
+
102
+
103
 
104
  @spaces.GPU(enable_queue=True)
105
+ def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed, upload_to_s3, region, access_key, secret_key, progress=gr.Progress()):
106
 
107
  if image_url:
108
  orginal_image = load_image(image_url)
 
124
  generator=generator,
125
  image=depth_image
126
  ).images[0]
127
+
128
+ if upload_to_s3:
129
+ url = upload_to_s3(generated_image, region, access_key, secret_key, bucket)
130
+ result = {"status": "success", "url": url}
131
+ else:
132
+ result = {"status": "success", "message": "Image generated but not uploaded"}
133
+
134
+ return [[depth_image, generated_image], json.dumps(result)]
135
 
136
  with gr.Blocks() as demo:
137
 
 
153
  label="Negative prompt",
154
  value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
155
  )
156
+
157
+ upload_to_s3 = gr.Checkbox(label="Upload to S3", value=False)
158
+ region = gr.Textbox(label="S3 Region", placeholder="Enter S3 region here")
159
+ access_key = gr.Textbox(label="Access Key", placeholder="Enter S3 access key here")
160
+ secret_key = gr.Textbox(label="Secret Key", placeholder="Enter S3 secret key here")
161
+ bucket = gr.Textbox(label="Bucket Name", placeholder="Enter S3 bucket name here")
162
+
163
+
164
  with gr.Column():
165
  result = ImageSlider(label="Generate image", type="pil", slider_color="pink")
166
  logs = gr.Textbox(label="logs")
 
173
  num_steps,
174
  guidance_scale,
175
  control_strength,
176
+ seed,
177
+ upload_to_s3,
178
+ region,
179
+ access_key,
180
+ secret_key,
181
+ bucket
182
  ]
183
  run_button.click(
184
  fn=randomize_seed_fn,
requirements.txt CHANGED
@@ -9,4 +9,5 @@ requests
9
  spaces
10
  huggingface_hub
11
  controlnet-aux
12
- safetensors
 
 
9
  spaces
10
  huggingface_hub
11
  controlnet-aux
12
+ safetensors
13
+ boto3