liguang0115 committed
Commit 7b64bf4 · 1 Parent(s): 2df809d

Update README with project details and remove training script

Files changed (2)
  1. README.md +11 -76
  2. train.py +0 -263
README.md CHANGED
@@ -1,76 +1,11 @@
- <div align="center">
- <h1>VMem: Consistent Video Scene Generation with Surfel-Indexed View Memory</h1>
-
- <a href="https://v-mem.github.io/"><img src="https://img.shields.io/badge/%F0%9F%8F%A0%20Project%20Page-gray.svg"></a>
- <a href="http://arxiv.org/abs/2503.14489"><img src="https://img.shields.io/badge/%F0%9F%93%84%20arXiv-2503.14489-B31B1B.svg"></a>
- <a href="https://huggingface.co/liguang0115/vmem"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
- <a href="https://huggingface.co/spaces/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%9A%80%20Gradio%20Demo-Huggingface-orange"></a>
-
- [Runjia Li](https://runjiali-rl.github.io/), [Philip Torr](https://www.robots.ox.ac.uk/~phst/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Tomas Jakab](https://www.robots.ox.ac.uk/~tomj/)
- <br>
- <br>
- [University of Oxford](https://www.robots.ox.ac.uk/~vgg/)
- </div>
-
- <p align="center">
- <img src="assets/demo_teaser.gif" width="100%" alt="Teaser" style="border-radius:10px;"/>
- </p>
-
- <!-- <p align="center" border-radius="10px">
- <img src="assets/benchmark.png" width="100%" alt="teaser_page1"/>
- </p> -->
-
- # Overview
-
- `VMem` is a plug-and-play memory mechanism for image-set models that enables consistent scene generation.
- Existing methods either rely on inpainting with explicit geometry estimation, which suffers from inaccuracies, or use limited context windows in video-based approaches, leading to poor long-term coherence. To overcome these issues, we introduce the Surfel-Indexed View Memory (VMem), which anchors past views to the surface elements (surfels) they observed. This enables conditioning novel view generation on the most relevant past views rather than just the most recent ones, enhancing long-term scene consistency while reducing computational cost.
-
-
- # :wrench: Installation
-
- ```bash
- conda create -n vmem python=3.10
- conda activate vmem
- pip install -r requirements.txt
- ```
-
-
- # :rocket: Usage
-
- You need to authenticate with Hugging Face to download our model weights. Once set up, our code will handle the download automatically on your first run. You can authenticate by running
-
- ```bash
- # This will prompt you to enter your Hugging Face credentials.
- huggingface-cli login
- ```
-
- Once authenticated, go to our model card [here](https://huggingface.co/stabilityai/stable-virtual-camera) and enter your information for access.
-
- We provide a demo for you to interact with `VMem`. Simply run
-
- ```bash
- python app.py
- ```
-
-
- ## :heart: Acknowledgement
- This work is built on top of [CUT3R](https://github.com/CUT3R/CUT3R), [DUSt3R](https://github.com/naver/dust3r), and [Stable Virtual Camera](https://github.com/stability-ai/stable-virtual-camera). We thank the authors for their great work.
-
-
-
-
-
- # :books: Citing
-
- If you find this repository useful, please consider giving it a star :star: and a citation.
-
- ```
- @article{zhou2025stable,
-   title={Stable Virtual Camera: Generative View Synthesis with Diffusion Models},
-   author={Jensen (Jinghao) Zhou and Hang Gao and Vikram Voleti and Aaryaman Vasishta and Chun-Han Yao and Mark Boss and
-   Philip Torr and Christian Rupprecht and Varun Jampani
-   },
-   journal={arXiv preprint arXiv:2503.14489},
-   year={2025}
- }
- ```
 
+ ---
+ title: Stable Virtual Camera
+ emoji: ⚡
+ colorFrom: yellow
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 5.33.0
+ app_file: app.py
+ pinned: false
+ ---
+ - **Project Page**: [https://v-mem.github.io/](https://v-mem.github.io/)
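
The Overview section removed above explains the core idea: past views are anchored to the surfels they observed, and novel view generation is conditioned on the past views most relevant to the queried camera pose. As a rough illustration of that retrieval step only — not the actual VMem implementation, and with all class and variable names hypothetical — a surfel-indexed view memory can be sketched as:

```python
# Illustrative toy sketch of surfel-indexed view retrieval (not VMem's code).
from collections import defaultdict


class ToySurfelViewMemory:
    def __init__(self):
        # surfel id -> set of past-view indices that observed this surfel
        self.observers = defaultdict(set)

    def add_view(self, view_idx, observed_surfel_ids):
        """Register a past view under every surfel it observed."""
        for sid in observed_surfel_ids:
            self.observers[sid].add(view_idx)

    def retrieve(self, visible_surfel_ids, top_k=4):
        """Return the past views sharing the most surfels with the query pose."""
        votes = defaultdict(int)
        for sid in visible_surfel_ids:
            for view_idx in self.observers.get(sid, ()):
                votes[view_idx] += 1
        ranked = sorted(votes.items(), key=lambda kv: kv[1], reverse=True)
        return [view_idx for view_idx, _ in ranked[:top_k]]


# Usage: condition generation of the next view on the retrieved past views.
memory = ToySurfelViewMemory()
memory.add_view(0, observed_surfel_ids=[1, 2, 3, 4])
memory.add_view(1, observed_surfel_ids=[3, 4, 5, 6])
memory.add_view(2, observed_surfel_ids=[7, 8, 9])
print(memory.retrieve(visible_surfel_ids=[3, 4, 5], top_k=2))  # -> [1, 0]
```

In this toy version, retrieval ranks past views by how many queried surfels they share, which is what lets the memory surface relevant but temporally distant views instead of only the most recent ones.
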
train.py DELETED
@@ -1,263 +0,0 @@
- import argparse
- from datetime import datetime
- import random
- import os
- import time
- import multiprocessing
-
- # Set multiprocessing start method to 'spawn' to avoid CUDA initialization issues in forked processes
- multiprocessing.set_start_method('spawn', force=True)
-
-
- from tqdm.auto import tqdm  # Progress bar
- import numpy as np
- from omegaconf import OmegaConf
-
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- from torch.optim.lr_scheduler import SequentialLR, LambdaLR, CosineAnnealingLR, ExponentialLR  # Importing CosineAnnealingLR scheduler
- import torch.nn.functional as F
-
-
-
- from accelerate import Accelerator, DistributedDataParallelKwargs
- from accelerate.utils import set_seed  # Removed get_scheduler import
-
- from peft import get_peft_model, LoraConfig
-
- from modeling import VMemModel
- from modeling.modules.autoencoder import AutoEncoder
- from modeling.sampling import DDPMDiscretization, DiscreteDenoiser, create_samplers
- from modeling.modules.conditioner import CLIPConditioner
-
- from utils.training_utils import DiffusionTrainer, load_pretrained_model
- from data.dataset import RealEstatePoseImageSevaDataset
-
-
-
-
- # set random seed for reproducibility
- torch.manual_seed(42)
- random.seed(42)
- np.random.seed(42)
-
-
-
- def parse_args():
-     parser = argparse.ArgumentParser(description='Train a model')
-     parser.add_argument('--config', type=str, default="", required=True, help='Path to the config file')
-     args = parser.parse_args()
-     return args
-
-
- def generate_current_datetime():
-     return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-
- def prepare_model(unet, config):
-     assert isinstance(unet, VMemModel), "unet should be an instance of VMemModel"
-     if config.training.lora_flag:
-         target_modules = []
-         for name, param in unet.named_parameters():
-             # # if ("temporal" in name or "transformer" in name) and "norm" not in name:
-             print(name)
-             if ("transformer" in name or "emb" in name or "layers" in name) \
-                 and "norm" not in name and "in_layers.0" not in name and "out_layers.0" not in name:
-                 # print(name)
-                 name = name.replace(".weight", "")
-                 name = name.replace(".bias", "")
-                 if name not in target_modules:
-                     target_modules.append(str(name))
-
-         lora_config = LoraConfig(
-             r=config.training.lora_r,
-             lora_alpha=config.training.lora_alpha,
-             target_modules=target_modules,
-             lora_dropout=config.training.lora_dropout,
-             # bias="none",
-         )
-         lora_config.target_modules = target_modules
-
-         unet = get_peft_model(unet, lora_config)
-         # for name, param in unet.named_parameters():
-         #     if "camera" in name or "control" in name or "context" in name or "epipolar" in name or "appearance" in name:
-         #         print(name)
-         #         param.requires_grad = True
-
-         unet.print_trainable_parameters()
-     else:
-         for name, param in unet.named_parameters():
-             param.requires_grad = True
-
-     print("trainable parameters percentage: ", np.sum([p.numel() for p in unet.parameters() if p.requires_grad])/np.sum([p.numel() for p in unet.parameters()]))
-     return unet
-
-
-
-
- def main():
-     args = parse_args()
-     config_path = args.config
-     config = OmegaConf.load(config_path)
-
-     # Load the configuration
-     num_epochs = config.training.num_epochs
-     batch_size = config.training.batch_size
-     learning_rate = config.training.learning_rate
-     gradient_accumulation_steps = config.training.gradient_accumulation_steps
-     num_workers = config.training.num_workers
-     warmup_epochs = config.training.warmup_epochs
-     max_grad_norm = config.training.max_grad_norm
-     validation_interval = config.training.validation_interval
-     visualization_flag = config.training.visualization_flag
-     visualize_every = config.training.visualize_every
-     random_seed = config.training.random_seed
-     save_flag = config.training.save_flag
-     use_wandb = config.training.use_wandb
-     samples_dir = config.training.samples_dir
-
-
-
-     weights_save_dir = config.training.weights_save_dir
-
-
-     resume = config.training.resume
-
-
-
-     exp_id = generate_current_datetime()
-     if visualization_flag:
-         run_visualization_dir = f"{samples_dir}/{exp_id}"
-         os.makedirs(run_visualization_dir, exist_ok=True)
-     else:
-         run_visualization_dir = None
-     if save_flag:
-         run_weights_save_dir = f"{weights_save_dir}/{exp_id}"
-         os.makedirs(run_weights_save_dir, exist_ok=True)
-     else:
-         run_weights_save_dir = None
-
-
-     accelerator = Accelerator(
-         mixed_precision="fp16",
-         gradient_accumulation_steps=gradient_accumulation_steps,
-         kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=False)],
-     )
-     num_gpus = accelerator.num_processes
-
-     if random_seed is not None:
-         set_seed(random_seed, device_specific=True)
-     device = accelerator.device
-
-
-
-     model = load_pretrained_model(cache_dir=config.model.cache_dir, device=device)
-
-
-     model = prepare_model(model, config)
-     if resume:
-         model.load_state_dict(torch.load(resume, map_location='cpu'), strict=False)
-     torch.cuda.empty_cache()
-
-     # model = model.to(device)
-
-
-     # time.sleep(100*3600)
-
-
-
-     train_dataset = RealEstatePoseImageSevaDataset(rgb_data_dir=config.dataset.realestate10k.rgb_data_dir,
-                                                    meta_info_dir=config.dataset.realestate10k.meta_info_dir,
-                                                    num_sample_per_episode=config.dataset.realestate10k.num_sample_per_episode,
-                                                    mode='train')
-     val_dataset = RealEstatePoseImageSevaDataset(rgb_data_dir=config.dataset.realestate10k.rgb_data_dir,
-                                                  meta_info_dir=config.dataset.realestate10k.meta_info_dir,
-                                                  num_sample_per_episode=config.dataset.realestate10k.val_num_sample_per_episode,
-                                                  mode='test')
-
-
-     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, multiprocessing_context='spawn')
-     val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, multiprocessing_context='spawn')
-
-     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=config.training.weight_decay)
-     train_steps_per_epoch = len(train_dataloader)
-     total_train_steps = num_epochs * train_steps_per_epoch
-     warmup_steps = warmup_epochs * train_steps_per_epoch
-
-     lr_scheduler = CosineAnnealingLR(
-         optimizer, T_max=total_train_steps - warmup_steps, eta_min=0
-     )
-
-     # lr_scheduler = ExponentialLR(optimizer, gamma=gamma)
-     if warmup_epochs > 0:
-         def warmup_lambda(current_step):
-             return float(current_step) / float(max(1, warmup_steps))
-         warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)
-
-
-         # Combine the schedulers using SequentialLR
-         lr_scheduler = SequentialLR(
-             optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]
-         )
-     vae = AutoEncoder(chunk_size=1).to(device)
-     vae.eval()
-     conditioner = CLIPConditioner().to(device)
-     discretization = DDPMDiscretization()
-     denoiser = DiscreteDenoiser(discretization=discretization, num_idx=1000, device=device)
-     sampler = create_samplers(guider_types=config.training.guider_types,
-                               discretization=discretization,
-                               num_frames=config.model.num_frames,
-                               num_steps=config.training.inference_num_steps,
-                               cfg_min=config.training.cfg_min,
-                               device=device)
-
-
-     (model,
-      vae,
-      train_dataloader,
-      val_dataloader,
-      optimizer,
-      lr_scheduler) = accelerator.prepare(
-         model,
-         vae,
-         train_dataloader,
-         val_dataloader,
-         optimizer,
-         lr_scheduler,
-     )
-
-
-     trainer = DiffusionTrainer(network=model,
-                                ae=vae,
-                                conditioner=conditioner,
-                                denoiser=denoiser,
-                                sampler=sampler,
-                                discretization=discretization,
-                                cfg=config.training.cfg,
-                                optimizer=optimizer,
-                                lr_scheduler=lr_scheduler,
-                                ema_decay=config.training.ema_decay,
-                                device=device,
-                                accelerator=accelerator,
-                                max_grad_norm=max_grad_norm,
-                                save_flag=save_flag,
-                                visualize_flag=visualization_flag)
-
-
-
-     trainer.train(train_dataloader,
-                   num_epochs,
-                   unconditional_prob=config.training.uncond_prob,
-                   log_every=10,
-                   validation_dataloader=val_dataloader,
-                   validation_interval=validation_interval,
-                   save_dir=run_weights_save_dir,
-                   save_interval=config.training.save_every,
-                   visualize_every=visualize_every,
-                   visualize_dir=run_visualization_dir,
-                   use_wandb=use_wandb)
-
-
- if __name__ == "__main__":
-     main()
-
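
For reference, the deleted train.py builds its learning-rate schedule by chaining a linear warmup into cosine annealing with PyTorch's SequentialLR. A minimal standalone sketch of that scheduling pattern follows; the dummy model and step counts are made up for illustration, whereas the real script derives warmup_steps and total_steps from the config and the dataloader length:

```python
# Sketch of the warmup-then-cosine schedule used in the removed training script.
import torch
from torch.optim.lr_scheduler import SequentialLR, LambdaLR, CosineAnnealingLR

model = torch.nn.Linear(8, 8)                       # dummy model for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

warmup_steps = 100                                  # made-up step counts
total_steps = 1000

# Linear warmup from 0 up to the base learning rate over warmup_steps.
warmup = LambdaLR(optimizer, lr_lambda=lambda step: step / max(1, warmup_steps))
# Cosine decay over the remaining steps, ending at eta_min=0.
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=0)
# Run warmup first, then switch to cosine at warmup_steps.
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

for step in range(total_steps):
    optimizer.step()      # training forward/backward elided in this sketch
    scheduler.step()
```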