liuhuohuo commited on
Commit
81ea85b
1 Parent(s): 729d3c5
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -13,6 +13,7 @@ from einops import repeat
13
  import torchvision.transforms as transforms
14
  from torchvision.utils import make_grid
15
  from utils.utils import instantiate_from_config
 
16
 
17
  from collections import OrderedDict
18
 
@@ -108,6 +109,7 @@ def infer(image, prompt, infer_type='image', seed=123, style_strength=1.0, steps
108
  torchvision.transforms.Lambda(lambda x: x * 2. - 1.),
109
  ])
110
 
 
111
  style_img = style_transforms(image).unsqueeze(0).cuda()
112
  style_cond = model.get_batch_style(style_img)
113
  append_to_context = model.adapter(style_cond)
 
13
  import torchvision.transforms as transforms
14
  from torchvision.utils import make_grid
15
  from utils.utils import instantiate_from_config
16
+ from PIL import Image
17
 
18
  from collections import OrderedDict
19
 
 
109
  torchvision.transforms.Lambda(lambda x: x * 2. - 1.),
110
  ])
111
 
112
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
113
  style_img = style_transforms(image).unsqueeze(0).cuda()
114
  style_cond = model.get_batch_style(style_img)
115
  append_to_context = model.adapter(style_cond)