vmem / extern /CUT3R /datasets_preprocess /preprocess_spring.py
liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
#!/usr/bin/env python3
"""
Preprocessing Script for Spring Dataset
This script:
- Recursively processes each sequence in a given 'root_dir' for the Spring dataset.
- Reads RGB, disparity, optical flow files, and camera intrinsics/extrinsics.
- Converts disparity to depth, rescales images/flows, and writes processed results
(RGB, Depth, Cam intrinsics/poses, Forward Flow, Backward Flow) to 'out_dir'.
Usage:
python preprocess_spring.py \
--root_dir /path/to/spring/train \
--out_dir /path/to/processed_spring \
--baseline 0.065 \
--output_size 960 540
"""
import os
import argparse
import numpy as np
import cv2
from PIL import Image
import shutil
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
# Custom modules (adapt these imports to your actual module locations)
import flow_IO
import src.dust3r.datasets.utils.cropping as cropping
def rescale_flow(flow, size):
"""
Resize an optical flow field to a new resolution and scale its vectors accordingly.
Args:
flow (np.ndarray): Flow array of shape [H, W, 2].
size (tuple): Desired (width, height) for the resized flow.
Returns:
np.ndarray: Resized and scaled flow array.
"""
h, w = flow.shape[:2]
new_w, new_h = size
# Resize the flow map
flow_resized = cv2.resize(
flow.astype("float32"), (new_w, new_h), interpolation=cv2.INTER_LINEAR
)
# Scale the flow vectors to match the new resolution
flow_resized[..., 0] *= new_w / w
flow_resized[..., 1] *= new_h / h
return flow_resized
def get_depth(disparity, fx_baseline):
"""
Convert disparity to depth using baseline * focal_length / disparity.
Args:
disparity (np.ndarray): Disparity map (same resolution as the RGB).
fx_baseline (float): Product of the focal length (fx) and baseline.
Returns:
np.ndarray: Depth map.
"""
# Avoid divide-by-zero
depth = np.zeros_like(disparity, dtype=np.float32)
valid_mask = disparity != 0
depth[valid_mask] = fx_baseline / disparity[valid_mask]
return depth
def process_sequence(seq, root_dir, out_dir, baseline, output_size):
"""
Process a single sequence from the Spring dataset:
- Reads RGB frames, disparity maps, forward/backward optical flow, intrinsics, extrinsics.
- Converts disparity to depth.
- Rescales images, depth, and flow to the specified 'output_size'.
- Saves the processed data to the output directory.
Args:
seq (str): Name of the sequence (subdirectory).
root_dir (str): Root directory containing the Spring dataset sequences.
out_dir (str): Output directory to store processed files.
baseline (float): Stereo baseline for disparity-to-depth conversion (SPRING_BASELINE).
output_size (tuple): (width, height) for output images and flows.
Returns:
None or str:
- Returns None if processing is successful.
- Returns an error message (str) if an error occurs.
"""
seq_dir = os.path.join(root_dir, seq)
img_dir = os.path.join(seq_dir, "frame_left")
disp1_dir = os.path.join(seq_dir, "disp1_left")
fflow_dir = os.path.join(seq_dir, "flow_FW_left")
bflow_dir = os.path.join(seq_dir, "flow_BW_left")
intrinsics_path = os.path.join(seq_dir, "cam_data", "intrinsics.txt")
extrinsics_path = os.path.join(seq_dir, "cam_data", "extrinsics.txt")
try:
# Check required files/folders
for path in (
img_dir,
disp1_dir,
fflow_dir,
bflow_dir,
intrinsics_path,
extrinsics_path,
):
if not os.path.exists(path):
return f"Missing required path: {path}"
# Prepare output directories
out_img_dir = os.path.join(out_dir, seq, "rgb")
out_depth_dir = os.path.join(out_dir, seq, "depth")
out_cam_dir = os.path.join(out_dir, seq, "cam")
out_fflow_dir = os.path.join(out_dir, seq, "flow_forward")
out_bflow_dir = os.path.join(out_dir, seq, "flow_backward")
for d in [
out_img_dir,
out_depth_dir,
out_cam_dir,
out_fflow_dir,
out_bflow_dir,
]:
os.makedirs(d, exist_ok=True)
# Read camera data
all_intrinsics = np.loadtxt(intrinsics_path)
all_extrinsics = np.loadtxt(extrinsics_path)
# Collect filenames
rgbs = sorted([f for f in os.listdir(img_dir) if f.endswith(".png")])
disps = sorted([f for f in os.listdir(disp1_dir) if f.endswith(".dsp5")])
fflows = sorted([f for f in os.listdir(fflow_dir) if f.endswith(".flo5")])
bflows = sorted([f for f in os.listdir(bflow_dir) if f.endswith(".flo5")])
# Basic consistency check
if not (len(all_intrinsics) == len(all_extrinsics) == len(rgbs) == len(disps)):
return (
f"Inconsistent lengths in {seq}: "
f"Intrinsics {len(all_intrinsics)}, "
f"Extrinsics {len(all_extrinsics)}, "
f"RGBs {len(rgbs)}, "
f"Disparities {len(disps)}"
)
# Note: fflows+1 == len(all_intrinsics), bflows+1 == len(all_intrinsics)
# Check if already processed
if len(os.listdir(out_img_dir)) == len(rgbs):
return None # Already done, skip
# Process each frame
for i in tqdm(
range(len(all_intrinsics)), desc=f"Processing {seq}", leave=False
):
frame_num = i + 1 # frames appear as 1-based in filenames
img_path = os.path.join(img_dir, f"frame_left_{frame_num:04d}.png")
disp1_path = os.path.join(disp1_dir, f"disp1_left_{frame_num:04d}.dsp5")
fflow_path = None
bflow_path = None
if i < len(all_intrinsics) - 1:
fflow_path = os.path.join(
fflow_dir, f"flow_FW_left_{frame_num:04d}.flo5"
)
if i > 0:
bflow_path = os.path.join(
bflow_dir, f"flow_BW_left_{frame_num:04d}.flo5"
)
# Load image
image = Image.open(img_path).convert("RGB")
# Build the intrinsics matrix
K = np.eye(3, dtype=np.float32)
K[0, 0] = all_intrinsics[i][0] # fx
K[1, 1] = all_intrinsics[i][1] # fy
K[0, 2] = all_intrinsics[i][2] # cx
K[1, 2] = all_intrinsics[i][3] # cy
# Build the pose
cam_ext = all_extrinsics[i].reshape(4, 4)
pose = np.linalg.inv(cam_ext).astype(np.float32)
if np.any(np.isinf(pose)) or np.any(np.isnan(pose)):
return f"Invalid pose for frame {i} in {seq}"
# Load disparity
disp1 = flow_IO.readDispFile(disp1_path)
# Subsample by 2
disp1 = disp1[::2, ::2]
# Convert disparity to depth
fx_baseline = all_intrinsics[i][0] * baseline # fx * baseline
depth = get_depth(disp1, fx_baseline)
depth[np.isinf(depth)] = 0.0
depth[np.isnan(depth)] = 0.0
# Load optical flows if available
fflow = None
bflow = None
if fflow_path and os.path.exists(fflow_path):
fflow = flow_IO.readFlowFile(fflow_path)
fflow = fflow[::2, ::2]
if bflow_path and os.path.exists(bflow_path):
bflow = flow_IO.readFlowFile(bflow_path)
bflow = bflow[::2, ::2]
# Rescale image, depth, and intrinsics
image, depth, K_scaled = cropping.rescale_image_depthmap(
image, depth, K, output_size
)
W_new, H_new = image.size # after rescale_image_depthmap
# Rescale forward/backward flow
if fflow is not None:
fflow = rescale_flow(fflow, (W_new, H_new))
if bflow is not None:
bflow = rescale_flow(bflow, (W_new, H_new))
# Save output
out_index_str = f"{i:04d}"
out_img_path = os.path.join(out_img_dir, out_index_str + ".png")
image.save(out_img_path)
out_depth_path = os.path.join(out_depth_dir, out_index_str + ".npy")
np.save(out_depth_path, depth)
out_cam_path = os.path.join(out_cam_dir, out_index_str + ".npz")
np.savez(out_cam_path, intrinsics=K_scaled, pose=pose)
if fflow is not None:
out_fflow_path = os.path.join(out_fflow_dir, out_index_str + ".npy")
np.save(out_fflow_path, fflow)
if bflow is not None:
out_bflow_path = os.path.join(out_bflow_dir, out_index_str + ".npy")
np.save(out_bflow_path, bflow)
except Exception as e:
return f"Error processing sequence {seq}: {e}"
return None # success
def main():
parser = argparse.ArgumentParser(description="Preprocess Spring dataset.")
parser.add_argument(
"--root_dir",
required=True,
help="Path to the root directory containing Spring dataset sequences.",
)
parser.add_argument(
"--out_dir",
required=True,
help="Path to the output directory where processed files will be saved.",
)
parser.add_argument(
"--baseline",
type=float,
default=0.065,
help="Stereo baseline for disparity-to-depth conversion (default: 0.065).",
)
parser.add_argument(
"--output_size",
type=int,
nargs=2,
default=[960, 540],
help="Output image size (width height) for rescaling.",
)
args = parser.parse_args()
# Gather sequences
if not os.path.isdir(args.root_dir):
raise ValueError(f"Root directory not found: {args.root_dir}")
os.makedirs(args.out_dir, exist_ok=True)
seqs = sorted(
[
d
for d in os.listdir(args.root_dir)
if os.path.isdir(os.path.join(args.root_dir, d))
]
)
if not seqs:
raise ValueError(f"No valid sequence folders found in {args.root_dir}")
# Process each sequence in parallel
with ProcessPoolExecutor(max_workers=os.cpu_count() // 2) as executor:
future_to_seq = {
executor.submit(
process_sequence,
seq,
args.root_dir,
args.out_dir,
args.baseline,
args.output_size,
): seq
for seq in seqs
}
for future in tqdm(
as_completed(future_to_seq),
total=len(future_to_seq),
desc="Processing all sequences",
):
seq = future_to_seq[future]
error = future.result()
if error:
print(f"Sequence '{seq}' failed: {error}")
if __name__ == "__main__":
main()