Mikiko Bazeley committed on
Commit
f3198c2
·
1 Parent(s): b621941

Reproduced the fjord example

Browse files
edge.png ADDED
fjord.png ADDED
output_image.jpg ADDED
pages/{5_FLUX_image_generation.py → 5_Test_FLUX_image_generation.py} RENAMED
File without changes
pages/6_Test_Control_Net_Flux.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import cv2
import requests
from io import BytesIO
from PIL import Image
import numpy as np
import os
from dotenv import load_dotenv
import streamlit as st

# Set the page title for the Streamlit app.
# NOTE: st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="ControlNet Image Generation")

# Path to the .env file, relative to this pages/ script.
# NOTE(review): assumes an env/ directory one level above pages/ — confirm
# against the repository layout.
dotenv_path = os.path.join(os.path.dirname(__file__), '../env/.env')

# Load environment variables from the .env file.
# override=True lets values in the .env file win over pre-set shell variables.
load_dotenv(dotenv_path, override=True)

# Get the Fireworks API key from the .env file
api_key = os.getenv("FIREWORKS_API_KEY")

# Fail fast: show a visible error and halt script execution (st.stop) so no
# API request is ever attempted without credentials.
if not api_key:
    st.error("API key not found. Make sure FIREWORKS_API_KEY is set in the .env file.")
    st.stop()
25
+
26
def process_image(uploaded_image):
    """Run Canny edge detection on an uploaded image.

    Returns a tuple ``(byte_stream, pil_image)`` where ``byte_stream`` is an
    in-memory JPEG of the edge map (rewound, ready to POST) and ``pil_image``
    is the same edge map as a PIL image for display.
    """
    # Decode the upload straight to grayscale — Canny operates on one channel.
    gray = np.array(Image.open(uploaded_image).convert('L'))

    # Edge-detect with fixed low/high thresholds, then promote the result to
    # 3-channel RGB because the downstream API expects a colour image.
    edge_map = cv2.cvtColor(cv2.Canny(gray, 100, 200), cv2.COLOR_GRAY2RGB)
    edge_pil = Image.fromarray(edge_map)

    # Serialize to an in-memory JPEG and rewind so it can be read by requests.
    buffer = BytesIO()
    edge_pil.save(buffer, format='JPEG')
    buffer.seek(0)

    return buffer, edge_pil
44
+
45
# Function to make the POST request to the API
def call_control_net_api(uploaded_image, prompt, control_mode=0, aspect_ratio="16:9",
                         guidance_scale=3.5, num_inference_steps=30, seed=0, controlnet_conditioning_scale=1.0):
    """Call the Fireworks FLUX ControlNet endpoint with an edge-conditioned prompt.

    Args:
        uploaded_image: File-like object from st.file_uploader; edge-detected
            via process_image before being sent as the control image.
        prompt: Text prompt for generation.
        control_mode: How the control image is used (0-2 per the API).
        aspect_ratio: Output aspect ratio string, e.g. "16:9".
        guidance_scale: Prompt-adherence strength.
        num_inference_steps: Diffusion step count.
        seed: RNG seed (0 means random, per the UI help text).
        controlnet_conditioning_scale: Strength of the control image's influence.

    Returns:
        (generated_image, processed_edge_image) as PIL images on success,
        (None, None) on any request or HTTP failure (an error is shown via
        st.error rather than raised).
    """
    # Process the image for control net (Canny edge map as JPEG bytes)
    control_image, processed_image = process_image(uploaded_image)

    # Multipart file payload: the edge map acts as the ControlNet condition.
    files = {
        'control_image': ('control_image.jpg', control_image, 'image/jpeg')
    }
    data = {
        'prompt': prompt,
        'control_mode': control_mode,  # Control how the control image is used
        'aspect_ratio': aspect_ratio,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'seed': seed,
        'controlnet_conditioning_scale': controlnet_conditioning_scale  # Control how strongly the control image influences
    }

    headers = {
        'accept': 'image/jpeg',
        'authorization': f'Bearer {api_key}',  # Using the API key loaded from the .env file
    }

    # Send the POST request. A timeout is essential: without one, a stalled
    # connection would hang the Streamlit worker indefinitely. 120s allows
    # for slow image generation while still bounding the wait.
    try:
        response = requests.post('https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/flux-1-dev-controlnet-union/control_net',
                                 files=files, data=data, headers=headers, timeout=120)
    except requests.RequestException as e:
        # Network-level failure (DNS, connection, timeout) — surface it in the
        # UI instead of letting the exception crash the page.
        st.error(f"Request failed: {e}")
        return None, None

    # Handle the response
    if response.status_code == 200:
        # Success: the body is the generated JPEG image.
        return Image.open(BytesIO(response.content)), processed_image
    else:
        st.error(f"Request failed with status code: {response.status_code}, Response: {response.text}")
        return None, None
80
+
81
# Streamlit UI: page header and instructions.
st.title("ControlNet Image Generation with Fireworks")
st.write("Upload an image, provide a prompt, and let the model generate an image using Canny edge detection as input.")

# File uploader for image input (only raster formats the pipeline can decode).
uploaded_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

# Text input for prompt
prompt = st.text_input("Enter your prompt")

# Expander to hide/show additional parameters. Defaults here mirror the
# API-call defaults except controlnet_conditioning_scale (0.5 in the UI
# vs 1.0 in the function signature).
with st.expander("Advanced Parameters"):
    # Slider for `control_mode` (integer 0-2, meaning defined by the API)
    control_mode = st.slider("Control Mode", min_value=0, max_value=2, value=0, help="Control how the edge image affects generation")

    # Slider for `controlnet_conditioning_scale`
    controlnet_conditioning_scale = st.slider("ControlNet Conditioning Scale", min_value=0.0, max_value=1.0, value=0.5, step=0.1,
                                              help="Adjust how strongly the edge-detected image influences the output")

    # Dropdown for `aspect_ratio`
    aspect_ratio = st.selectbox("Aspect Ratio", options=["16:9", "1:1", "4:3", "3:2", "9:16"], index=0)

    # Slider for `guidance_scale`
    guidance_scale = st.slider("Guidance Scale", min_value=0.0, max_value=20.0, value=3.5, step=0.1,
                               help="Adjust how strongly the model adheres to the prompt")

    # Slider for `num_inference_steps`
    num_inference_steps = st.slider("Number of Inference Steps", min_value=1, max_value=100, value=30, step=1,
                                    help="Number of steps to generate the image")

    # Slider for `seed`
    seed = st.slider("Random Seed", min_value=0, max_value=1000, value=0,
                     help="Set a seed for reproducibility (0 means random)")

# Button to submit: validate inputs before calling the API.
if st.button("Generate Image"):
    if uploaded_image is None:
        st.error("Please upload an image.")
    elif not prompt.strip():
        st.error("Please enter a prompt.")
    else:
        with st.spinner("Processing..."):
            # Display the uploaded image first
            st.subheader("Uploaded Image")
            st.image(uploaded_image, caption="Original Uploaded Image", use_column_width=True)

            # Call the ControlNet API with all advanced-parameter values.
            generated_image, processed_image = call_control_net_api(uploaded_image, prompt,
                                                                    control_mode=control_mode,
                                                                    aspect_ratio=aspect_ratio,
                                                                    guidance_scale=guidance_scale,
                                                                    num_inference_steps=num_inference_steps,
                                                                    seed=seed,
                                                                    controlnet_conditioning_scale=controlnet_conditioning_scale)

            # call_control_net_api returns (None, None) on failure and has
            # already shown an st.error, so only render on success.
            if generated_image:
                # Hide the processed edge-detected image in an expander
                with st.expander("Edge Detection Result (Input to ControlNet)"):
                    st.image(processed_image, caption="Processed Edge Detection Image", use_column_width=True)

                # Display the generated image from the API
                st.subheader("Generated Image")
                st.image(generated_image, caption="Generated Image from ControlNet", use_column_width=True)
requirements.txt CHANGED
@@ -135,6 +135,7 @@ numpy==1.26.4
135
  oauthlib==3.2.2
136
  onnxruntime==1.19.2
137
  openai==1.47.0
 
138
  opentelemetry-api==1.27.0
139
  opentelemetry-exporter-otlp-proto-common==1.27.0
140
  opentelemetry-exporter-otlp-proto-grpc==1.27.0
 
135
  oauthlib==3.2.2
136
  onnxruntime==1.19.2
137
  openai==1.47.0
138
+ opencv-python==4.10.0.84
139
  opentelemetry-api==1.27.0
140
  opentelemetry-exporter-otlp-proto-common==1.27.0
141
  opentelemetry-exporter-otlp-proto-grpc==1.27.0
pages/test_endpoint.py → test_endpoint.py RENAMED
File without changes