nateraw commited on
Commit
4c42dda
·
1 Parent(s): ae29498

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ model_map = torch.hub.load('nateraw/image-generation:main', 'model_map')
4
+ class InferenceWrapper:
5
+ def __init__(self, model):
6
+ self.model = model
7
+ self.pipe = torch.hub.load('nateraw/image-generation:main', 'styleganv3', pretrained=self.model, videos=True)
8
+ def __call__(self, seed1, seed2, seed3, w_frames, model):
9
+ if model != self.model:
10
+ print(f"Loading model: {model}")
11
+ self.model = model
12
+ self.pipe = torch.hub.load('nateraw/image-generation:main', 'styleganv3', pretrained=self.model, videos=True)
13
+ else:
14
+ print(f"Model '{model}' already loaded, reusing it.")
15
+ return self.pipe([seed1, seed2, seed3], w_frames=w_frames)
16
+ wrapper = InferenceWrapper('wikiart-1024')
17
+ def fn(seed, model):
18
+ return wrapper(seed, model)
19
+ gr.Interface(
20
+ fn,
21
+ inputs=[
22
+ gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed For Image 1'),
23
+ gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed For Image 2'),
24
+ gr.inputs.Slider(minimum=0, maximum=999999999, step=1, default=0, label='Random Seed For Image 3'),
25
+ gr.inputs.Radio([60, 120, 240], type="value", default=60, label='Frames'),
26
+ gr.inputs.Radio(list(model_map), type="value", default='wikiart-1024', label='Pretrained Model')
27
+ ],
28
+ outputs='image',
29
+ examples=[[0, 1, 2, 60, 'wikiart-1024']],
30
+ enable_queue=True
31
+ ).launch()