Mikiko Bazeley committed
Commit f3198c2
Parent(s): b621941
Reproduced the fjord example
- edge.png +0 -0
- fjord.png +0 -0
- output_image.jpg +0 -0
- pages/{5_FLUX_image_generation.py → 5_Test_FLUX_image_generation.py} +0 -0
- pages/6_Test_Control_Net_Flux.py +143 -0
- requirements.txt +1 -0
- pages/test_endpoint.py → test_endpoint.py +0 -0
edge.png
ADDED
fjord.png
ADDED
output_image.jpg
ADDED
pages/{5_FLUX_image_generation.py → 5_Test_FLUX_image_generation.py}
RENAMED
File without changes
pages/6_Test_Control_Net_Flux.py
ADDED
@@ -0,0 +1,143 @@
import cv2
import requests
from io import BytesIO
from PIL import Image
import numpy as np
import os
from dotenv import load_dotenv
import streamlit as st

# Set the page title for the Streamlit app
st.set_page_config(page_title="ControlNet Image Generation")

# Correct the path to the .env file to reflect its location
dotenv_path = os.path.join(os.path.dirname(__file__), '../env/.env')

# Load environment variables from the .env file
load_dotenv(dotenv_path, override=True)

# Get the Fireworks API key from the .env file
api_key = os.getenv("FIREWORKS_API_KEY")

if not api_key:
    st.error("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
    st.stop()

# Load image and apply Canny edge detection
def process_image(uploaded_image):
    # Convert the uploaded image into an OpenCV-compatible format (grayscale)
    image = np.array(Image.open(uploaded_image).convert('L'))

    # Apply Canny edge detection
    edges = cv2.Canny(image, 100, 200)

    # Convert the single-channel edges image to a 3-channel image (RGB)
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

    # Convert the edges image to a format that can be sent in the HTTP request
    pil_image = Image.fromarray(edges_rgb)
    byte_arr = BytesIO()
    pil_image.save(byte_arr, format='JPEG')
    byte_arr.seek(0)

    return byte_arr, pil_image

# Function to make the POST request to the API
def call_control_net_api(uploaded_image, prompt, control_mode=0, aspect_ratio="16:9",
                         guidance_scale=3.5, num_inference_steps=30, seed=0, controlnet_conditioning_scale=1.0):
    # Process the image for ControlNet
    control_image, processed_image = process_image(uploaded_image)

    # Prepare the payload
    files = {
        'control_image': ('control_image.jpg', control_image, 'image/jpeg')
    }
    data = {
        'prompt': prompt,
        'control_mode': control_mode,  # Control how the control image is used
        'aspect_ratio': aspect_ratio,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'seed': seed,
        'controlnet_conditioning_scale': controlnet_conditioning_scale  # Control how strongly the control image influences
    }

    headers = {
        'accept': 'image/jpeg',
        'authorization': f'Bearer {api_key}',  # Using the API key loaded from the .env file
    }

    # Send the POST request
    response = requests.post('https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-controlnet-union/control_net',
                             files=files, data=data, headers=headers)

    # Handle the response
    if response.status_code == 200:
        return Image.open(BytesIO(response.content)), processed_image
    else:
        st.error(f"Request failed with status code: {response.status_code}, Response: {response.text}")
        return None, None

# Streamlit UI
st.title("ControlNet Image Generation with Fireworks")
st.write("Upload an image, provide a prompt, and let the model generate an image using Canny edge detection as input.")

# File uploader for image input
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Text input for prompt
prompt = st.text_input("Enter your prompt")

# Expander to hide/show additional parameters
with st.expander("Advanced Parameters"):
    # Slider for `control_mode`
    control_mode = st.slider("Control Mode", min_value=0, max_value=2, value=0, help="Control how the edge image affects generation")

    # Slider for `controlnet_conditioning_scale`
    controlnet_conditioning_scale = st.slider("ControlNet Conditioning Scale", min_value=0.0, max_value=1.0, value=0.5, step=0.1,
                                              help="Adjust how strongly the edge-detected image influences the output")

    # Dropdown for `aspect_ratio`
    aspect_ratio = st.selectbox("Aspect Ratio", options=["16:9", "1:1", "4:3", "3:2", "9:16"], index=0)

    # Slider for `guidance_scale`
    guidance_scale = st.slider("Guidance Scale", min_value=0.0, max_value=20.0, value=3.5, step=0.1,
                               help="Adjust how strongly the model adheres to the prompt")

    # Slider for `num_inference_steps`
    num_inference_steps = st.slider("Number of Inference Steps", min_value=1, max_value=100, value=30, step=1,
                                    help="Number of steps to generate the image")

    # Slider for `seed`
    seed = st.slider("Random Seed", min_value=0, max_value=1000, value=0,
                     help="Set a seed for reproducibility (0 means random)")

# Button to submit
if st.button("Generate Image"):
    if uploaded_image is None:
        st.error("Please upload an image.")
    elif not prompt.strip():
        st.error("Please enter a prompt.")
    else:
        with st.spinner("Processing..."):
            # Display the uploaded image first
            st.subheader("Uploaded Image")
            st.image(uploaded_image, caption="Original Uploaded Image", use_column_width=True)

            # Call the ControlNet API
            generated_image, processed_image = call_control_net_api(uploaded_image, prompt,
                                                                     control_mode=control_mode,
                                                                     aspect_ratio=aspect_ratio,
                                                                     guidance_scale=guidance_scale,
                                                                     num_inference_steps=num_inference_steps,
                                                                     seed=seed,
                                                                     controlnet_conditioning_scale=controlnet_conditioning_scale)

            if generated_image:
                # Hide the processed edge-detected image in an expander
                with st.expander("Edge Detection Result (Input to ControlNet)"):
                    st.image(processed_image, caption="Processed Edge Detection Image", use_column_width=True)

                # Display the generated image from the API
                st.subheader("Generated Image")
                st.image(generated_image, caption="Generated Image from ControlNet", use_column_width=True)
requirements.txt
CHANGED
@@ -135,6 +135,7 @@ numpy==1.26.4
 oauthlib==3.2.2
 onnxruntime==1.19.2
 openai==1.47.0
+opencv-python==4.10.0.84
 opentelemetry-api==1.27.0
 opentelemetry-exporter-otlp-proto-common==1.27.0
 opentelemetry-exporter-otlp-proto-grpc==1.27.0
pages/test_endpoint.py → test_endpoint.py
RENAMED
File without changes