#!/usr/bin/env python3
"""
Preprocessing script for the Spring dataset.

This script:
  - Recursively processes each sequence in a given 'root_dir' of the Spring dataset.
  - Reads RGB frames, disparity maps, optical flow files, and camera intrinsics/extrinsics.
  - Converts disparity to depth, rescales images/flows, and writes the processed results
    (RGB, depth, camera intrinsics/poses, forward flow, backward flow) to 'out_dir'.

Usage:
    python preprocess_spring.py \
        --root_dir /path/to/spring/train \
        --out_dir /path/to/processed_spring \
        --baseline 0.065 \
        --output_size 960 540
"""

import os
import argparse
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

# Custom modules (adapt these imports to your actual module locations)
import flow_IO
import src.dust3r.datasets.utils.cropping as cropping


def rescale_flow(flow, size):
    """
    Resize an optical flow field to a new resolution and scale its vectors accordingly.

    Args:
        flow (np.ndarray): Flow array of shape [H, W, 2].
        size (tuple): Desired (width, height) for the resized flow.

    Returns:
        np.ndarray: Resized and scaled flow array.
    """
    h, w = flow.shape[:2]
    new_w, new_h = size
    # Resize the flow map
    flow_resized = cv2.resize(
        flow.astype("float32"), (new_w, new_h), interpolation=cv2.INTER_LINEAR
    )
    # Scale the flow vectors to match the new resolution
    flow_resized[..., 0] *= new_w / w
    flow_resized[..., 1] *= new_h / h
    return flow_resized
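
# Example (hypothetical values): downscaling a 1920x1080 flow field to 960x540
# also halves the flow vectors:
#   flow_small = rescale_flow(flow, (960, 540))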


def get_depth(disparity, fx_baseline):
    """
    Convert disparity to depth as depth = fx * baseline / disparity.

    Args:
        disparity (np.ndarray): Disparity map in pixels (same resolution as the RGB).
        fx_baseline (float): Product of the focal length fx (in pixels) and the stereo baseline.

    Returns:
        np.ndarray: Depth map in the same units as the baseline; zero where disparity is zero.
    """
    # Avoid divide-by-zero
    depth = np.zeros_like(disparity, dtype=np.float32)
    valid_mask = disparity != 0
    depth[valid_mask] = fx_baseline / disparity[valid_mask]
    return depth
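
# Example (hypothetical values): with fx = 1400 px and baseline = 0.065 m,
# a disparity of 10 px gives depth = 1400 * 0.065 / 10 = 9.1 m.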


def process_sequence(seq, root_dir, out_dir, baseline, output_size):
    """
    Process a single sequence from the Spring dataset:
      - Reads RGB frames, disparity maps, forward/backward optical flow, intrinsics, and extrinsics.
      - Converts disparity to depth.
      - Rescales images, depth, and flow to the specified 'output_size'.
      - Saves the processed data to the output directory.

    Args:
        seq (str): Name of the sequence (subdirectory).
        root_dir (str): Root directory containing the Spring dataset sequences.
        out_dir (str): Output directory to store processed files.
        baseline (float): Stereo baseline for disparity-to-depth conversion (SPRING_BASELINE).
        output_size (tuple): (width, height) for output images and flows.

    Returns:
        None or str:
            - None if processing succeeds.
            - An error message (str) if an error occurs.
    """
    seq_dir = os.path.join(root_dir, seq)
    img_dir = os.path.join(seq_dir, "frame_left")
    disp1_dir = os.path.join(seq_dir, "disp1_left")
    fflow_dir = os.path.join(seq_dir, "flow_FW_left")
    bflow_dir = os.path.join(seq_dir, "flow_BW_left")
    intrinsics_path = os.path.join(seq_dir, "cam_data", "intrinsics.txt")
    extrinsics_path = os.path.join(seq_dir, "cam_data", "extrinsics.txt")
    try:
        # Check required files/folders
        for path in (
            img_dir,
            disp1_dir,
            fflow_dir,
            bflow_dir,
            intrinsics_path,
            extrinsics_path,
        ):
            if not os.path.exists(path):
                return f"Missing required path: {path}"

        # Prepare output directories
        out_img_dir = os.path.join(out_dir, seq, "rgb")
        out_depth_dir = os.path.join(out_dir, seq, "depth")
        out_cam_dir = os.path.join(out_dir, seq, "cam")
        out_fflow_dir = os.path.join(out_dir, seq, "flow_forward")
        out_bflow_dir = os.path.join(out_dir, seq, "flow_backward")
        for d in [
            out_img_dir,
            out_depth_dir,
            out_cam_dir,
            out_fflow_dir,
            out_bflow_dir,
        ]:
            os.makedirs(d, exist_ok=True)

        # Read camera data
        all_intrinsics = np.loadtxt(intrinsics_path)
        all_extrinsics = np.loadtxt(extrinsics_path)
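        # Assumed layout: intrinsics.txt has one line per frame (fx fy cx cy);
        # extrinsics.txt has one flattened 4x4 world-to-camera matrix per frame.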

        # Collect filenames
        rgbs = sorted([f for f in os.listdir(img_dir) if f.endswith(".png")])
        disps = sorted([f for f in os.listdir(disp1_dir) if f.endswith(".dsp5")])
        fflows = sorted([f for f in os.listdir(fflow_dir) if f.endswith(".flo5")])
        bflows = sorted([f for f in os.listdir(bflow_dir) if f.endswith(".flo5")])

        # Basic consistency check
        if not (len(all_intrinsics) == len(all_extrinsics) == len(rgbs) == len(disps)):
            return (
                f"Inconsistent lengths in {seq}: "
                f"Intrinsics {len(all_intrinsics)}, "
                f"Extrinsics {len(all_extrinsics)}, "
                f"RGBs {len(rgbs)}, "
                f"Disparities {len(disps)}"
            )
        # Note: len(fflows) + 1 == len(all_intrinsics) and len(bflows) + 1 == len(all_intrinsics),
        # since the last frame has no forward flow and the first frame has no backward flow.

        # Skip sequences that have already been processed
        if len(os.listdir(out_img_dir)) == len(rgbs):
            return None  # Already done, skip

        # Process each frame
        for i in tqdm(
            range(len(all_intrinsics)), desc=f"Processing {seq}", leave=False
        ):
            frame_num = i + 1  # frames are 1-based in the filenames
            img_path = os.path.join(img_dir, f"frame_left_{frame_num:04d}.png")
            disp1_path = os.path.join(disp1_dir, f"disp1_left_{frame_num:04d}.dsp5")
            fflow_path = None
            bflow_path = None
            if i < len(all_intrinsics) - 1:
                fflow_path = os.path.join(
                    fflow_dir, f"flow_FW_left_{frame_num:04d}.flo5"
                )
            if i > 0:
                bflow_path = os.path.join(
                    bflow_dir, f"flow_BW_left_{frame_num:04d}.flo5"
                )

            # Load image
            image = Image.open(img_path).convert("RGB")

            # Build the intrinsics matrix
            K = np.eye(3, dtype=np.float32)
            K[0, 0] = all_intrinsics[i][0]  # fx
            K[1, 1] = all_intrinsics[i][1]  # fy
            K[0, 2] = all_intrinsics[i][2]  # cx
            K[1, 2] = all_intrinsics[i][3]  # cy

            # Build the pose
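            # (the extrinsics are assumed to be world-to-camera transforms;
            # inverting yields the camera-to-world pose that gets saved)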
            cam_ext = all_extrinsics[i].reshape(4, 4)
            pose = np.linalg.inv(cam_ext).astype(np.float32)
            if np.any(np.isinf(pose)) or np.any(np.isnan(pose)):
                return f"Invalid pose for frame {i} in {seq}"

            # Load disparity
            disp1 = flow_IO.readDispFile(disp1_path)
            # Subsample by 2: Spring ground truth is stored at twice the RGB resolution
            disp1 = disp1[::2, ::2]

            # Convert disparity to depth
            fx_baseline = all_intrinsics[i][0] * baseline  # fx * baseline
            depth = get_depth(disp1, fx_baseline)
            depth[np.isinf(depth)] = 0.0
            depth[np.isnan(depth)] = 0.0

            # Load optical flows if available
            fflow = None
            bflow = None
            if fflow_path and os.path.exists(fflow_path):
                fflow = flow_IO.readFlowFile(fflow_path)
                fflow = fflow[::2, ::2]
            if bflow_path and os.path.exists(bflow_path):
                bflow = flow_IO.readFlowFile(bflow_path)
                bflow = bflow[::2, ::2]

            # Rescale image, depth, and intrinsics
            image, depth, K_scaled = cropping.rescale_image_depthmap(
                image, depth, K, output_size
            )
            W_new, H_new = image.size  # after rescale_image_depthmap
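            # Use the actual post-rescale size rather than 'output_size', in case
            # rescale_image_depthmap adjusts the requested resolution.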

            # Rescale forward/backward flow
            if fflow is not None:
                fflow = rescale_flow(fflow, (W_new, H_new))
            if bflow is not None:
                bflow = rescale_flow(bflow, (W_new, H_new))

            # Save outputs
            out_index_str = f"{i:04d}"
            out_img_path = os.path.join(out_img_dir, out_index_str + ".png")
            image.save(out_img_path)

            out_depth_path = os.path.join(out_depth_dir, out_index_str + ".npy")
            np.save(out_depth_path, depth)

            out_cam_path = os.path.join(out_cam_dir, out_index_str + ".npz")
            np.savez(out_cam_path, intrinsics=K_scaled, pose=pose)

            if fflow is not None:
                out_fflow_path = os.path.join(out_fflow_dir, out_index_str + ".npy")
                np.save(out_fflow_path, fflow)
            if bflow is not None:
                out_bflow_path = os.path.join(out_bflow_dir, out_index_str + ".npy")
                np.save(out_bflow_path, bflow)

    except Exception as e:
        return f"Error processing sequence {seq}: {e}"

    return None  # success


def main():
    parser = argparse.ArgumentParser(description="Preprocess the Spring dataset.")
    parser.add_argument(
        "--root_dir",
        required=True,
        help="Path to the root directory containing Spring dataset sequences.",
    )
    parser.add_argument(
        "--out_dir",
        required=True,
        help="Path to the output directory where processed files will be saved.",
    )
    parser.add_argument(
        "--baseline",
        type=float,
        default=0.065,
        help="Stereo baseline for disparity-to-depth conversion (default: 0.065).",
    )
    parser.add_argument(
        "--output_size",
        type=int,
        nargs=2,
        default=[960, 540],
        help="Output image size (width height) for rescaling.",
    )
    args = parser.parse_args()

    # Gather sequences
    if not os.path.isdir(args.root_dir):
        raise ValueError(f"Root directory not found: {args.root_dir}")
    os.makedirs(args.out_dir, exist_ok=True)
    seqs = sorted(
        [
            d
            for d in os.listdir(args.root_dir)
            if os.path.isdir(os.path.join(args.root_dir, d))
        ]
    )
    if not seqs:
        raise ValueError(f"No valid sequence folders found in {args.root_dir}")

    # Process each sequence in parallel (always use at least one worker)
    with ProcessPoolExecutor(max_workers=max(1, os.cpu_count() // 2)) as executor:
        future_to_seq = {
            executor.submit(
                process_sequence,
                seq,
                args.root_dir,
                args.out_dir,
                args.baseline,
                args.output_size,
            ): seq
            for seq in seqs
        }
        for future in tqdm(
            as_completed(future_to_seq),
            total=len(future_to_seq),
            desc="Processing all sequences",
        ):
            seq = future_to_seq[future]
            error = future.result()
            if error:
                print(f"Sequence '{seq}' failed: {error}")


if __name__ == "__main__":
    main()