Wiuhh commited on
Commit
42eac23
·
verified ·
1 Parent(s): aeff9f9

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +38 -21
src/streamlit_app.py CHANGED
@@ -16,12 +16,12 @@ API_URL_DALLE = "https://api-inference.huggingface.co/models/ehristoforu/dalle-3
16
  headers = {"Authorization": f"Bearer {hf_token}"}
17
 
18
  # Function to query Hugging Face API
19
- def query(payload):
20
- response = requests.post(API_URL, headers=headers, json=payload)
21
  return response.content
22
 
23
  # Streamlit UI
24
- st.title("Text To Image models")
25
  st.write("Choose model and enter a prompt")
26
 
27
  model = st.selectbox(
@@ -31,21 +31,38 @@ model = st.selectbox(
31
 
32
  prompt = st.text_input("Enter prompt")
33
 
34
- # Generate Image
35
- if prompt:
36
- if model == "KVIImageR2.0":
37
- API_URL = API_URL_KVI
38
- elif model == "Midjourney V6":
39
- API_URL = API_URL_MJ
40
- elif model == "Dalle 3":
41
- API_URL = API_URL_DALLE
42
-
43
- image_bytes = query({"inputs": prompt})
44
-
45
- try:
46
- image = Image.open(io.BytesIO(image_bytes))
47
- st.image(image, caption="Generated Image")
48
- st.info("Image generated successfully!")
49
- except Exception as e:
50
- st.error("Failed to generate image. Please try again.")
51
- st.text(str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  headers = {"Authorization": f"Bearer {hf_token}"}
17
 
18
  # Function to query Hugging Face API
19
+ def query(payload, api_url):
20
+ response = requests.post(api_url, headers=headers, json=payload)
21
  return response.content
22
 
23
  # Streamlit UI
24
+ st.title("Text To Image Models")
25
  st.write("Choose model and enter a prompt")
26
 
27
  model = st.selectbox(
 
31
 
32
  prompt = st.text_input("Enter prompt")
33
 
34
+ # Button for generating image
35
+ if st.button("Generate Image"):
36
+ if prompt:
37
+ if model == "KVIImageR2.0":
38
+ API_URL = API_URL_KVI
39
+ elif model == "Midjourney V6":
40
+ API_URL = API_URL_MJ
41
+ elif model == "Dalle 3":
42
+ API_URL = API_URL_DALLE
43
+
44
+ with st.spinner("Generating image... Please wait."):
45
+ image_bytes = query({"inputs": prompt}, API_URL)
46
+
47
+ try:
48
+ image = Image.open(io.BytesIO(image_bytes))
49
+
50
+ # Image preview
51
+ st.image(image, caption="Generated Image Preview", use_column_width=True)
52
+
53
+ # Download option
54
+ img_buffer = io.BytesIO()
55
+ image.save(img_buffer, format="PNG")
56
+ st.download_button(
57
+ label="Download Image",
58
+ data=img_buffer.getvalue(),
59
+ file_name="generated_image.png",
60
+ mime="image/png"
61
+ )
62
+
63
+ st.success("Image generated successfully!")
64
+ except Exception as e:
65
+ st.error("Failed to generate image. Please try again.")
66
+ st.text(str(e))
67
+ else:
68
+ st.warning("Please enter a prompt before generating.")