Depth-Anything-ONNX
Metric depth inference
I'd like to try inference with a model trained on metric depth data. I've exported the Depth Anything NYU indoor pretrained model (trained on metric indoor data) to ONNX and run inference, as shown below, but the results indicate I'm doing something wrong.
I adapted the provided export script (see below), then invoke it for 'base' model export via:

python .\export_metric.py --model_type "metric" --model b
The export process seemed to go fine, and checking the model didn't turn up any errors:
import onnx
from onnx import checker
model = onnx.load("depth_anything_vitb14_metric.onnx")
checker.check_model(model)
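As an extra sanity check, I also inspect the exported graph's interface with onnxruntime (a quick sketch on my side; the "image"/"depth" names come from the export script below):

import onnxruntime as ort

# Load the exported model on CPU just to inspect its declared inputs/outputs
session = ort.InferenceSession(
    "depth_anything_vitb14_metric.onnx", providers=["CPUExecutionProvider"]
)

# Expecting one input "image" and one output "depth" with dynamic height/width axes
for inp in session.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)

If the declared names or shapes were off, the session.run(None, {"image": image}) call in batch_infer.py would fail outright, so this mainly rules out interface problems.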
I can launch inference from a variant of the supplied eval script (also below). Using one of the provided .onnx models works as expected, e.g.:

python .\batch_infer.py --img_dir .\input --model .\weights\depth_anything_vitb14.onnx --output_dir .\output
Below are an input image and the corresponding relative depth map:
I invoke inference with my exported model as follows:

python .\batch_infer.py --img_dir .\input --model .\weights\depth_anything_vitb14_metric.onnx --output_dir .\output_metric
However, the resulting depth maps show tell-tale repeating block artifacts:
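One thing I haven't ruled out is that strict=False in the export script below is silently dropping weights. A diagnostic I could run (a sketch, assuming the checkpoint loads as a plain dict, as the export script does) is to count how many checkpoint keys actually match the DPT_DINOv2 state dict:

import torch

from depth_anything.dpt import DPT_DINOv2

# Same 'base' configuration as in export_metric.py below
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])

checkpoint = torch.load("./metric_source/depth_anything_metric_depth_indoor.pt", map_location="cpu")
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

# With strict=False, load_state_dict returns the keys it skipped rather than raising
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"{len(missing)} missing keys, {len(unexpected)} unexpected keys")
print("sample unexpected keys:", unexpected[:5])

A large number of missing/unexpected keys would mean the metric checkpoint isn't a plain DPT_DINOv2 state dict, and the exported model would then be running with partly random weights.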
'export_metric.py':

import argparse

import torch
from onnx import load_model, save_model
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

from depth_anything.dpt import DPT_DINOv2
from depth_anything.util.transform import load_image


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        choices=["s", "b", "l"],
        required=True,
        help="Model size variant. Available options: 's', 'b', 'l'.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        choices=["metric", "relative"],
        required=True,
        help="Model type. Available options: 'metric', 'relative'.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        required=False,
        help="Path to save the ONNX model.",
    )
    return parser.parse_args()


def export_onnx(model: str, model_type: str, output: str = None):
    # Handle args
    if output is None:
        output = f"weights/depth_anything_vit{model}14_{model_type}.onnx"

    # Device for tracing
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sample image for tracing
    image, _ = load_image("assets/grace.jpg")
    image = torch.from_numpy(image).to(device)

    # Initialize model instance based on model size
    if model == "s":
        depth_anything = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
    elif model == "b":
        depth_anything = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
    else:  # model == "l"
        depth_anything = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])

    # Load checkpoint
    if model_type == "metric":
        checkpoint = torch.load("./metric_source/depth_anything_metric_depth_indoor.pt", map_location="cpu")
    else:  # model_type == "relative"
        checkpoint = torch.hub.load_state_dict_from_url(
            f"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vit{model}14.pth",
            map_location="cpu",
        )

    # Extract model weights from checkpoint
    if "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint

    # Load state dict into model; strict=False silently skips keys that don't match
    depth_anything.to(device).load_state_dict(state_dict, strict=False)
    depth_anything.eval()

    # Export to ONNX with dynamic height/width
    torch.onnx.export(
        depth_anything,
        image,
        output,
        input_names=["image"],
        output_names=["depth"],
        opset_version=17,
        dynamic_axes={
            "image": {2: "height", 3: "width"},
            "depth": {2: "height", 3: "width"},
        },
    )

    # Shape inference for the exported ONNX model
    save_model(
        SymbolicShapeInference.infer_shapes(load_model(output), auto_merge=True),
        output,
    )


if __name__ == "__main__":
    args = parse_args()
    export_onnx(**vars(args))
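To sanity-check an export beyond checker.check_model, I can also compare the ONNX output against the PyTorch model on the tracing image. This is a sketch meant to run at the end of export_onnx(), reusing its depth_anything, device, and output variables:

import numpy as np
import onnxruntime as ort

# Same preprocessing as used for tracing
image, _ = load_image("assets/grace.jpg")

# PyTorch reference output
with torch.no_grad():
    ref = depth_anything(torch.from_numpy(image).to(device)).cpu().numpy()

# ONNX Runtime output on the same input
session = ort.InferenceSession(output, providers=["CPUExecutionProvider"])
onnx_out = session.run(None, {"image": image})[0]

# A large max-abs difference here would point at the export itself,
# rather than at the checkpoint loading
print("max abs diff:", np.abs(ref - onnx_out).max())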
'batch_infer.py':
import argparse
import os
import time

import cv2
import numpy as np
import onnxruntime as ort

from depth_anything.util.transform import load_image


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_dir",
        type=str,
        required=True,
        help="Path to input image directory.",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to ONNX model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory to save output depth images.",
    )
    parser.add_argument(
        "--viz", action="store_true", help="Whether to visualize the results."
    )
    return parser.parse_args()


def infer(image_path: str, session, output_dir: str, viz: bool = False):
    start_time = time.time()

    image, (orig_h, orig_w) = load_image(image_path)
    depth = session.run(None, {"image": image})[0]
    depth = cv2.resize(depth[0, 0], (orig_w, orig_h))

    # Normalize to the full uint16 range for saving
    depth_scaled = np.clip(depth * (65535 / depth.max()), 0, 65535).astype(np.uint16)

    # Save grayscale depth image as PNG (without color map)
    output_path = os.path.join(output_dir, os.path.splitext(os.path.basename(image_path))[0] + ".png")
    cv2.imwrite(output_path, depth_scaled)

    end_time = time.time()
    processing_time = end_time - start_time
    print(f"Processed {image_path} in {processing_time:.2f} seconds")
    return processing_time


def main():
    args = parse_args()

    # Ensure output directory exists
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Measure model loading time
    start_model_loading = time.time()
    session = ort.InferenceSession(
        args.model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    model_loading_time = time.time() - start_model_loading
    print(f"Model loaded in {model_loading_time:.2f} seconds")

    # Process each image in the directory, accumulating per-image timings
    total_time_excluding_model_loading = 0
    image_count = 0
    for filename in os.listdir(args.img_dir):
        if filename.lower().endswith((".png", ".jpg", ".jpeg")):
            image_path = os.path.join(args.img_dir, filename)
            processing_time = infer(image_path, session, args.output_dir, args.viz)
            total_time_excluding_model_loading += processing_time
            image_count += 1

    if image_count > 0:
        average_time_excluding_model = total_time_excluding_model_loading / image_count
        average_time_including_model = (total_time_excluding_model_loading + model_loading_time) / image_count
        print(f"Average processing time per image (excluding model loading): {average_time_excluding_model:.2f} seconds")
        print(f"Average processing time per image (including model loading): {average_time_including_model:.2f} seconds")
    else:
        print("No images were processed.")


if __name__ == "__main__":
    main()
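One more thing I noticed while debugging: for a metric model, the per-image 65535 / depth.max() normalization in infer() discards the absolute scale, so even a correct metric export would lose its units at save time. If I get the export working, I'd save in fixed units instead, along these lines (a sketch assuming the network outputs depth in meters):

import cv2
import numpy as np

def save_metric_depth(depth: np.ndarray, output_path: str) -> None:
    # Store millimeters in a 16-bit PNG: 0-65.535 m range at 1 mm resolution,
    # preserving the absolute scale instead of normalizing per image
    depth_mm = np.clip(depth * 1000.0, 0, 65535).astype(np.uint16)
    cv2.imwrite(output_path, depth_mm)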