import io
import warnings
from io import BytesIO

import numpy as np
import requests
import streamlit as st
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from tensorflow.keras.models import load_model
# Silence every warning category for the whole process
warnings.simplefilter('ignore')

# Streamlit page configuration, collected in one mapping and applied in a single call
_PAGE_SETTINGS = {
    "page_title": "Sketch to Image using GAN",
    "layout": "centered",
    "page_icon": "🖌️",
}
st.set_page_config(**_PAGE_SETTINGS)
# Title and description
# NOTE(review): the HTML that originally wrapped these two strings was lost,
# leaving unterminated string literals. Centered <h1>/<p> tags are a
# reconstruction -- confirm against the intended design.
st.markdown(
    "<h1 style='text-align: center;'>Sketch to Image using GAN 🖌️</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<p style='text-align: center;'>Upload your sketch to generate an image!</p>",
    unsafe_allow_html=True,
)

# Upload file widget (returns None until the user picks a file)
uploaded_file = st.file_uploader("Upload a sketch (jpg, jpeg, png):", type=["jpg", "jpeg", "png"])
# Model Loading
# Bind the name up front so a failed load leaves it as None instead of
# undefined (which would surface later as a confusing NameError).
generator_model = None
try:
    generator_model = load_model('model.h5')  # Update this path to your actual model file
    st.success("Model loaded successfully!")
except Exception as e:
    # Top-level boundary: report the failure in the UI rather than crashing the page.
    st.error(f"Error loading the model: {str(e)}")
# Image processing function
def process_and_generate_image(image_data):
    """Run the generator on an encoded sketch and return the result as a PIL image.

    Args:
        image_data: Raw bytes of an encoded image file (anything Pillow can open).

    Returns:
        A ``PIL.Image.Image`` built from the generator output.

    Raises:
        RuntimeError: If the module-level ``generator_model`` is missing or None.
        OSError: Propagated from Pillow when ``image_data`` cannot be decoded.
    """
    # Look the model up by name so a failed/missing load produces a clear
    # error here instead of a NameError or AttributeError further down.
    model = globals().get("generator_model")
    if model is None:
        raise RuntimeError("Generator model is not loaded")

    sketch = Image.open(io.BytesIO(image_data)).convert('RGB')
    sketch = sketch.resize((256, 256))

    # Preprocess: scale pixel values from [0, 255] to [-1, 1] and add a batch axis
    batch = (np.asarray(sketch, dtype=np.float32) - 127.5) / 127.5
    batch = np.expand_dims(batch, axis=0)

    # Generate fake image
    fake_image = model.predict(batch)

    # Rescale to [0, 1]; this assumes the generator emits values in [-1, 1]
    fake_image = (np.asarray(fake_image) + 1) / 2.0
    fake_image = np.squeeze(fake_image)
    # Clip first: values slightly outside [0, 1] would otherwise wrap around
    # when cast to uint8 and show up as speckle artifacts.
    fake_image = np.clip(fake_image, 0.0, 1.0)
    fake_image = (fake_image * 255).astype(np.uint8)
    return Image.fromarray(fake_image)
# Show the uploaded sketch and, when the button is pressed, the generated result
if uploaded_file is not None:
    st.image(uploaded_file, caption="Uploaded Sketch", width=300)
    generate_clicked = st.button("Generate Image")
    if generate_clicked:
        with st.spinner('Generating...'):
            try:
                # Hand the raw upload bytes to the generator pipeline
                sketch_bytes = uploaded_file.getvalue()
                result_image = process_and_generate_image(sketch_bytes)
                # Render the generator output below the sketch
                st.image(result_image, caption="Generated Image", width=300)
            except Exception as exc:
                st.error(f"Error generating image: {exc!s}")
# FastAPI app for backend
app = FastAPI()


@app.post("/generate-image/")
async def generate_image(file: UploadFile = File(...)):
    """Generate an image from an uploaded sketch and return it as a JPEG.

    Args:
        file: The uploaded sketch, sent as multipart form data.

    Returns:
        A ``StreamingResponse`` carrying the generated image encoded as JPEG.

    Raises:
        HTTPException: 400 when the upload cannot be processed as an image.
    """
    contents = await file.read()
    try:
        generated_image = process_and_generate_image(contents)
    except (OSError, ValueError) as exc:
        # Pillow typically reports unreadable image data as an OSError
        # subclass; answer with a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail=f"Could not process uploaded image: {exc}",
        ) from exc

    # Encode into an in-memory buffer and rewind so the response reads from the start
    img_io = io.BytesIO()
    generated_image.save(img_io, 'JPEG')
    img_io.seek(0)
    return StreamingResponse(img_io, media_type="image/jpeg")
# Running FastAPI app if script is executed directly
if __name__ == '__main__':
    from streamlit import runtime as _st_runtime

    # `streamlit run` executes this script as __main__ too; starting uvicorn
    # there would block the UI, so only serve the API when no Streamlit
    # runtime is active (i.e. the file was launched with plain `python`).
    if not _st_runtime.exists():
        import uvicorn

        uvicorn.run(app, host="127.0.0.1", port=8000)