santoshtyss committed
Commit 1f54647 · 1 Parent(s): c22dffd

Create app.py

Files changed (1):
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
+ # Runtime dependency installs: these must run before the imports that rely on
+ # them (in a Hugging Face Space they would normally live in requirements.txt).
+ import os
+ 
+ os.system('pip install git+https://github.com/huggingface/transformers -q')
+ os.system('pip install git+https://github.com/huggingface/diffusers.git -q')
+ os.system('pip install accelerate')
+ os.system('pip install "transformers[sentencepiece]"')
+ os.system('pip install Pillow')
+ os.system('pip install gradio')
+ 
+ import io
+ 
+ import cv2
+ import matplotlib.pyplot as plt
+ import requests
+ import torch
+ from PIL import Image
+ 
+ from huggingface_hub import notebook_login
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
+ 
+ # notebook_login() targets notebooks; in a plain script, setting an HF_TOKEN
+ # environment variable (or calling huggingface_hub.login) is the usual route.
+ notebook_login()
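+ 
+ # Load the models the app needs: CLIPSeg for text-driven segmentation (to
+ # build inpainting masks from a product description), the Stable Diffusion
+ # inpainting pipeline (to repaint the masked region), NLLB-200 (to translate
+ # prompts from the user's language to English), and the base Stable Diffusion
+ # pipeline for free text-to-image generation.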
+ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+ model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+ 
+ device = "cuda"
+ IPmodel_path = "runwayml/stable-diffusion-inpainting"
+ 
+ IPpipe = StableDiffusionInpaintPipeline.from_pretrained(
+     IPmodel_path,
+     revision="fp16",
+     torch_dtype=torch.float16,
+ ).to(device)
+ 
+ trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+ trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
+ 
+ SDpipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16, use_auth_token=True).to(device)
+ 
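+ # create_mask: run CLIPSeg on the image with the product name as the text
+ # query, threshold the sigmoid heatmap into a black-and-white image, and
+ # invert it so the region to repaint is white, which is what the inpainting
+ # pipeline expects.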
+ def create_mask(image, prompt):
+     inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt")
+     # predict
+     with torch.no_grad():
+         outputs = model(**inputs)
+ 
+     preds = outputs.logits
+ 
+     filename = "mask.png"
+     plt.imsave(filename, torch.sigmoid(preds).squeeze().cpu().numpy())
+ 
+     gray_image = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2GRAY)
+ 
+     (thresh, bw_image) = cv2.threshold(gray_image, 100, 255, cv2.THRESH_BINARY)
+ 
+     # For debugging only:
+     # cv2.imwrite(filename, bw_image)
+ 
+     # bw_image is already single-channel, so no color conversion is needed;
+     # invert it so the region matching the prompt becomes the white area.
+     mask = cv2.bitwise_not(bw_image)
+     cv2.imwrite(filename, mask)
+ 
+     return Image.open(filename)
+ 
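+ # generate_image: build a mask for the named product, then inpaint the masked
+ # region with the target description; 512x512 is the native resolution of the
+ # Stable Diffusion v1 checkpoints.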
+ def generate_image(image, product_name, target_name):
+     mask = create_mask(image, product_name)
+     image = image.resize((512, 512))
+     mask = mask.resize((512, 512))
+     guidance_scale = 8
+     # guidance_scale = 16
+     num_samples = 4
+ 
+     prompt = target_name
+     generator = torch.Generator(device="cuda").manual_seed(22)  # change the seed to get different results
+ 
+     im = IPpipe(
+         prompt=prompt,
+         image=image,
+         mask_image=mask,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         num_images_per_prompt=num_samples,
+     ).images
+ 
+     return im
+ 
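+ # translate_sentence: translate text between FLORES-200 language codes with
+ # the NLLB-200 translation pipeline.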
+ def translate_sentence(article, source, target):
+     # Skip the round trip when no translation is needed. (The original check,
+     # target == 'eng_Latn', skipped translation whenever the target was
+     # English, even when the source language was not.)
+     if source == target:
+         return article
+     translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer, src_lang=source, tgt_lang=target, device=0)
+     output = translator(article, max_length=400)
+     output = output[0]['translation_text']
+     return output
+ 
+ # The original notebook defined codes_as_string as a "<language>\t<code>"
+ # table covering all FLORES-200 languages; that table is not part of this
+ # commit, so the subset below is a stand-in for a few NLLB-200 languages.
+ codes_as_string = '''English\teng_Latn
+ French\tfra_Latn
+ German\tdeu_Latn
+ Spanish\tspa_Latn
+ Hindi\thin_Deva
+ Portuguese\tpor_Latn
+ Russian\trus_Cyrl
+ Japanese\tjpn_Jpan'''
+ 
+ codes_as_string = codes_as_string.split('\n')
+ 
+ flores_codes = {}
+ for code in codes_as_string:
+     lang, lang_code = code.split('\t')
+     flores_codes[lang] = lang_code
+ 
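+ # Gradio front end: UI strings (retranslated on the fly when the language
+ # changes) and the prompt-modifier options for the Advanced Options panel.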
+ import gradio as gr
+ import gc
+ gc.collect()
+ # %env is IPython magic and is a syntax error in a script; set the variable
+ # through os.environ instead (it only affects allocations made afterwards).
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'
+ 
+ image_label = 'Please upload the image (optional)'
+ extract_label = 'Specify what needs to be extracted from the above image'
+ prompt_label = 'Specify the description of the image to be generated'
+ button_label = "Proceed"
+ output_label = "Generations"
+ 
+ shot_services = ['close-up', 'extreme closeup', 'POV', 'medium', 'long']
+ shot_label = 'Choose the shot type'
+ 
+ style_services = ['polaroid', 'monochrome', 'long exposure', 'color splash', 'tilt shift']
+ style_label = 'Choose the style type'
+ 
+ lighting_services = ['soft', 'ambient', 'ring', 'sun', 'cinematic']
+ lighting_label = 'Choose the lighting type'
+ 
+ context_services = ['indoor', 'outdoor', 'at night', 'in the park', 'on the beach', 'studio']
+ context_label = 'Choose the context'
+ 
+ lens_services = ['wide angle', 'telephoto', '24 mm', 'EF 70mm', 'bokeh']
+ lens_label = 'Choose the lens type'
+ 
+ device_services = ['iPhone', 'CCTV', 'Nikon ZFX', 'Canon', 'GoPro']
+ device_label = 'Choose the device type'
+ 
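+ # change_lang: when the user picks a language, translate the widget labels
+ # into it and reveal the hidden controls.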
+ # Default language until the dropdown is changed.
+ lang_choice = "English"
+ 
+ def change_lang(choice):
+     global lang_choice
+     lang_choice = choice
+     return [gr.update(visible=True, label=translate_sentence(image_label, flores_codes["English"], flores_codes[choice])),
+             gr.update(visible=True, label=translate_sentence(extract_label, flores_codes["English"], flores_codes[choice])),
+             gr.update(visible=True, label=translate_sentence(prompt_label, flores_codes["English"], flores_codes[choice])),
+             gr.update(visible=True, value=translate_sentence(button_label, flores_codes["English"], flores_codes[choice])),
+             # The gallery should carry the output label, not the button label.
+             gr.update(visible=True, label=translate_sentence(output_label, flores_codes["English"], flores_codes[choice])),
+             ]
+ 
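+ # add_to_prompt: append the selected advanced options to the prompt as
+ # comma-separated modifiers.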
+ def add_to_prompt(prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
+     # Unselected radios arrive as None, so test truthiness rather than != ''.
+     if shot_radio:
+         prompt_text += "," + shot_radio
+     if style_radio:
+         prompt_text += "," + style_radio
+     if lighting_radio:
+         prompt_text += "," + lighting_radio
+     if context_radio:
+         prompt_text += "," + context_radio
+     if lens_radio:
+         prompt_text += "," + lens_radio
+     if device_radio:
+         prompt_text += "," + device_radio
+     return prompt_text
+ 
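+ # proceed_with_generation: with an uploaded image plus an extraction phrase,
+ # run guided product replacement (inpainting); with a prompt alone, run free
+ # text-to-image generation; otherwise raise a validation error.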
+ def proceed_with_generation(input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio):
+     # An empty gr.Image arrives as None rather than "", and the guided branch
+     # needs the image to be present (the original conditions tested the
+     # opposite and would have crashed on Image.fromarray(None)).
+     if input_file is not None and extract_text != "" and prompt_text != "":
+         translated_prompt = translate_sentence(prompt_text, flores_codes[lang_choice], flores_codes["English"])
+         translated_prompt = add_to_prompt(translated_prompt, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
+         print(translated_prompt)
+         translated_extract = translate_sentence(extract_text, flores_codes[lang_choice], flores_codes["English"])
+         print(translated_extract)
+         output = generate_image(Image.fromarray(input_file), translated_extract, translated_prompt)
+         return output
+     elif prompt_text != "":
+         translated_prompt = translate_sentence(prompt_text, flores_codes[lang_choice], flores_codes["English"])
+         translated_prompt = add_to_prompt(translated_prompt, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio)
+         print(translated_prompt)
+         output = SDpipe(translated_prompt, height=512, width=512, num_images_per_prompt=4)
+         return output.images
+     else:
+         raise gr.Error("Please fill in all details for a guided image, or at least a prompt for a free image rendition!")
+ 
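+ # Gradio Blocks UI: a language dropdown, the image/extract/prompt inputs
+ # (hidden until a language is selected), the Advanced Options accordion, a
+ # Proceed button, and a gallery for the generated images.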
+ with gr.Blocks() as demo:
+ 
+     # gr.Dropdown takes value=, not default=, for its initial selection.
+     lang_option = gr.Dropdown(list(flores_codes.keys()), value='English', label='Please Select your Language')
+ 
+     with gr.Row():
+         input_file = gr.Image(interactive=True, label=image_label, visible=False, shape=(512, 512))
+         extract_text = gr.Textbox(label=extract_label, lines=1, interactive=True, visible=False)
+         prompt_text = gr.Textbox(label=prompt_label, lines=1, interactive=True, visible=False)
+ 
+     with gr.Accordion("Advanced Options"):
+         shot_radio = gr.Radio(shot_services, label=shot_label)
+         style_radio = gr.Radio(style_services, label=style_label)
+         lighting_radio = gr.Radio(lighting_services, label=lighting_label)
+         context_radio = gr.Radio(context_services, label=context_label)
+         lens_radio = gr.Radio(lens_services, label=lens_label)
+         device_radio = gr.Radio(device_services, label=device_label)
+ 
+     button = gr.Button(value=button_label, visible=False)
+ 
+     with gr.Row():
+         output_gallery = gr.Gallery(label=output_label, visible=False)
+ 
+     lang_option.change(fn=change_lang, inputs=lang_option, outputs=[input_file, extract_text, prompt_text, button, output_gallery])
+     button.click(proceed_with_generation, [input_file, extract_text, prompt_text, shot_radio, style_radio, lighting_radio, context_radio, lens_radio, device_radio], [output_gallery])
+ 
+ demo.launch(debug=True, share=True)