# Source captured from a Hugging Face Space page; the Space status at capture time was "Runtime error".
| from pathlib import Path | |
| from num2words import num2words | |
| import numpy as np | |
| import os | |
| import random | |
| import re | |
| import textwrap | |
| import torch | |
| from shapely.geometry.polygon import Polygon | |
| from shapely.affinity import scale | |
| import aggdraw | |
| from PIL import Image, ImageDraw, ImageOps, ImageFilter, ImageFont, ImageColor | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM | |
# Stock GPT-2 tokenizer, plus the causal LM loaded from the local 'model'
# directory (presumably a GPT-2 fine-tune, given the tokenizer -- confirm).
tokenizer = AutoTokenizer.from_pretrained('gpt2')
finetuned = AutoModelForCausalLM.from_pretrained('model')

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)
finetuned = finetuned.to(device)
| # Utility functions | |
def containsNumber(value):
    """Report whether *value* contains at least one digit.

    A character counts as a digit when ``str.isdigit`` says so.

    Args:
        value: the string (or iterable of characters) to scan.

    Returns:
        True if some character is a digit, False otherwise (including for
        empty input).
    """
    return any(char.isdigit() for char in value)
def creativity(intensity):
    """Map a UI creativity level to sampling parameters for ``generate``.

    Args:
        intensity: one of ``'Low'``, ``'Medium'`` or ``'High'``.

    Returns:
        ``(top_p, top_k)`` tuple used for nucleus / top-k sampling.

    Raises:
        ValueError: if *intensity* is not one of the known levels. Previously
            this surfaced as an ``UnboundLocalError``.
    """
    settings = {
        'Low': (0.95, 10),
        'Medium': (0.9, 50),
        'High': (0.85, 100),
    }
    try:
        return settings[intensity]
    except KeyError:
        raise ValueError(
            "Unknown creativity level {!r}; expected one of {}".format(
                intensity, ", ".join(settings))
        ) from None
# Room-type label (as emitted in the model's layout text) -> index into
# ``architext_colors``. Only indices 1-10 are used by these labels.
housegan_labels = {"living_room": 1, "kitchen": 2, "bedroom": 3, "bathroom": 4, "missing": 5, "closet": 6,
                   "balcony": 7, "corridor": 8, "dining_room": 9, "laundry_room": 10}

# RGB fill colour per label index. Index 0 and index 11 (a duplicate of
# index 7) are not referenced by ``housegan_labels``.
architext_colors = [[0, 0, 0], [249, 222, 182], [195, 209, 217], [250, 120, 128], [126, 202, 234],
                    [190, 0, 198], [255, 255, 255], [6, 53, 17], [17, 33, 58], [132, 151, 246],
                    [197, 203, 159], [6, 53, 17]]

# Captures the text inside each "(...)" group, e.g. "(12,34)" -> "12,34".
# Raw string: "\(" in a normal literal is an invalid escape sequence.
regex = re.compile(r".*?\((.*?)\)")
def _swapped_flat_coords(coords):
    """Flatten ``[(x, y), ...]`` into ``[y0, x0, y1, x1, ...]`` for aggdraw.

    The x/y swap is this renderer's existing drawing convention. Together with
    the vertical flip in ``draw_polygons`` it sets the on-screen orientation.
    Any extra coordinate dimensions (e.g. z) are dropped.
    """
    return np.array(list(coords), dtype=float)[:, 1::-1].flatten()


def draw_polygons(polygons, colors, im_size=(512, 512), b_color="white", fpath=None):
    """Rasterise room polygons into an RGBA image.

    Each polygon is first filled in black. It is then shrunk inward by one unit
    (``buffer(-1)``) and the shrunken interior is filled with the room's colour,
    so every room ends up with a thin black outline.

    Args:
        polygons: iterable of shapely Polygons in pixel units.
        colors: iterable of RGB triples, paired with *polygons* via ``zip``
            (surplus items in either sequence are ignored).
        im_size: ``(width, height)`` of the output image.
        b_color: canvas background colour (anything ``PIL.Image.new``
            accepts). Previously ignored; the default keeps the old output.
        fpath: optional path; when given, the final image is saved there.

    Returns:
        ``(draw, image)``: the aggdraw canvas and the vertically flipped PIL
        image built from its pixels.
    """
    image = Image.new("RGBA", im_size, color=b_color)
    draw = aggdraw.Draw(image)
    outline_brush = aggdraw.Brush((0, 0, 0), opacity=255)
    for poly, color in zip(polygons, colors):
        # Whole footprint in black; what remains visible becomes the outline.
        draw.polygon(_swapped_flat_coords(poly.exterior.coords), outline_brush)

        inner = poly.buffer(-1, resolution=32, cap_style=2, join_style=2, mitre_limit=5.0)
        # Shrinking a narrow-waisted room can split it into several pieces;
        # shrinking a very thin one can leave nothing. Multi-part geometries
        # must be walked via ``.geoms`` (they are not iterable in Shapely 2).
        if inner.geom_type == 'MultiPolygon':
            pieces = list(inner.geoms)
        elif inner.geom_type == 'Polygon' and not inner.is_empty:
            pieces = [inner]
        else:
            pieces = []

        # Every interior piece gets the room colour (pieces of a split room
        # were previously painted black).
        fill_brush = aggdraw.Brush((color[0], color[1], color[2]), opacity=255)
        for piece in pieces:
            draw.polygon(_swapped_flat_coords(piece.exterior.coords), fill_brush)

    # Pillow >= 9.1 exposes the enum; fall back to the legacy constant.
    flip = getattr(Image, "Transpose", Image).FLIP_TOP_BOTTOM
    image = Image.frombytes("RGBA", im_size, draw.tobytes()).transpose(flip)
    if fpath:
        image.save(fpath, quality=100, subsampling=0)
    return draw, image
def _spell_out_numbers(prompt):
    """Rewrite standalone digit runs as lowercase English words ("2" -> "two").

    Digits glued to letters (e.g. "3rd") are left alone.
    NOTE(review): presumably the model was fine-tuned on prompts with counts
    spelled out -- confirm against the training data.
    """
    return re.sub(r'\b\d+\b', lambda match: num2words(int(match.group())).lower(), prompt)


def _parse_layout(text):
    """Extract room labels and 2x-scaled polygons from decoded model output.

    The layout is read from the text between the first ``[User prompt]`` and
    any following one, after the ``[Layout] `` marker. It is expected to look
    like ``name: (x,y)(x,y)..., name: (x,y)...``.

    Returns:
        ``(spaces, geom)``: parallel lists of label strings and shapely polygons.

    Raises:
        ValueError: if the text has no ``[User prompt] ... [Layout] `` section.
    """
    try:
        body = text.split('[User prompt]')[1].split('[Layout] ')[1]
    except IndexError:
        raise ValueError('Model output has no layout section: {!r}'.format(text)) from None

    spaces, geom = [], []
    for entry in body.split(', '):
        name, sep, coord_text = entry.partition(':')
        points = [tuple(int(v) for v in point.split(','))
                  for point in re.findall(regex, coord_text)]
        # Generation stops at max_length, so the last entry may be cut off:
        # skip anything that is not a labelled ring of at least 3 points.
        if not sep or len(points) < 3:
            continue
        # Tolerate surrounding whitespace and numbered instances such as
        # "bedroom2" (if the model emits them).
        spaces.append(name.strip().rstrip('0123456789'))
        geom.append(scale(Polygon(points), xfact=2, yfact=2, origin=(0, 0)))
    return spaces, geom


def prompt_to_layout(user_prompt, intensity, fpath=None):
    """Generate a floor-plan image for a free-text apartment description.

    Args:
        user_prompt: description typed by the user; numbers in it are spelled
            out before prompting the model.
        intensity: creativity level ('Low', 'Medium' or 'High'); see ``creativity``.
        fpath: optional path where the bare layout image (without legend) is saved.

    Returns:
        A PIL image: the ``labels.png`` legend stacked above the rendered layout.

    Raises:
        ValueError: for an unknown *intensity* or unparseable model output.
        KeyError: if the model emits a room label missing from ``housegan_labels``.
    """
    top_p, top_k = creativity(intensity)
    model_prompt = '[User prompt] {} [Layout]'.format(_spell_out_numbers(user_prompt))
    inputs = tokenizer(model_prompt, return_tensors='pt').to(device)
    # 50256 is GPT-2's <|endoftext|> id (the tokenizer is loaded from 'gpt2').
    output = finetuned.generate(**inputs, do_sample=True, top_p=top_p, top_k=top_k,
                                eos_token_id=50256, max_length=400)
    text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]

    spaces, geom = _parse_layout(text)
    colors = [architext_colors[housegan_labels[space]] for space in spaces]
    _, im = draw_polygons(geom, colors, fpath=fpath)

    # Close the legend file promptly and match the layout's RGBA mode so the
    # arrays can be stacked (their widths must also match).
    with Image.open("labels.png") as legend:
        legend_arr = np.asarray(legend.convert("RGBA"))
    return Image.fromarray(np.vstack([legend_arr, np.asarray(im)]))
# Gradio app
#
# Stylesheet passed to ``gr.Interface(css=...)``. It pulls a Typekit font
# (garamond-premier-pro-display) and restyles the page, submit button, text
# input, radio buttons and output image.
# NOTE(review): the selectors (.gradio_page, .panel_button, .input_text, ...)
# look like Gradio 2.x markup class names; they may not match the DOM of newer
# Gradio releases -- confirm they still apply.
custom_css="""
@import url("https://use.typekit.net/nid3pfr.css");
.gradio_page {
display: flex;
width: 100vw;
min-height: 50vh;
flex-direction: column;
justify-content: center;
align-items: center;
margin: 0px;
max-width: 100vw;
background: #FFFFFF;
}
.gradio_interface {
width: 100vw;
max-width: 1500px;
}
.gradio_page[theme=default] .panel_buttons {
justify-content: flex-end;
}
.gradio_page[theme=default] .panel_button {
flex: 0 0 0;
min-width: 150px;
}
.gradio_page[theme=default] .gradio_interface .panel_button.submit {
background: #11213A;
border-radius: 5px;
color: #FFFFFF;
text-transform: uppercase;
min-width: 150px;
height: 4em;
letter-spacing: 0.15em;
flex: 0 0 0;
}
.gradio_page[theme=default] .gradio_interface .panel_button.submit:hover {
background: #000000;
}
.input_text:focus {
border-color: #FA7880;
}
.gradio_page[theme=default] .gradio_interface .input_text input,
.gradio_page[theme=default] .gradio_interface .input_text textarea {
font: 200 45px garamond-premier-pro-display, serif;
line-height: 110%;
color: #11213A;
border-radius: 5px;
padding: 15px;
border: none;
background: #F2F4F4;
}
.input_text textarea:focus-visible {
outline: none;
}
.gradio_page[theme=default] .gradio_interface .input_radio .radio_item.selected {
background-color: #11213A;
}
.gradio_page[theme=default] .gradio_interface .input_radio .selected .radio_circle {
border-color: #4365c4;
}
.gradio_page[theme=default] .gradio_interface .output_image {
width: 100%;
height: 40vw;
max-height: 630px;
}
.gradio_page[theme=default] .gradio_interface .output_image .image_preview_holder {
background: transparent;
}
.panel:nth-child(1) {
margin-left: 50px;
margin-right: 50px;
margin-bottom: 80px;
max-width: 750px;
}
.panel {
background: transparent;
}
.gradio_page[theme=default] .gradio_interface .component_set {
background: transparent;
}
.panel:nth-child(2) .gradio_page[theme=default] .gradio_interface .panel_header {
display: none;
}
.labels {
height: 20px;
width: auto;
}
"""
# UI wiring: a free-text description and a creativity level go in, and one
# generated layout image comes out.
# Components are the top-level gr.* classes. The old gr.inputs / gr.outputs
# namespaces and the `allow_screenshot` argument no longer exist in current
# Gradio, and `allow_flagging` takes a string mode rather than a bool.
creative_slider = gr.Radio(["Low", "Medium", "High"], value="Low", label='Creativity')
textbox = gr.Textbox(placeholder='An apartment with two bedrooms and one bathroom', lines=3,
                     label="DESCRIBE YOUR IDEAL APARTMENT")
generated = gr.Image(label='Generated Layout')
iface = gr.Interface(fn=prompt_to_layout, inputs=[textbox, creative_slider],
                     outputs=[generated],
                     css=custom_css,
                     theme="default",
                     allow_flagging="never",
                     thumbnail="thumbnail_gradio.PNG")

# Start the server only when run as a script (as Spaces does with app.py),
# so the module can be imported without launching the UI.
if __name__ == "__main__":
    iface.launch()