File size: 3,579 Bytes
5d472bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e26714
5d472bd
 
 
 
 
9e26714
 
5d472bd
 
9e26714
 
 
5d472bd
 
 
9e26714
5d472bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e26714
5d472bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import sys
sys.path.append('Depth-Anything-V2')

import cv2
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from depth_anything_v2.dpt import DepthAnythingV2
from pathlib import Path
from tqdm.auto import tqdm
import argparse


def parse_args():
    """Read this script's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with two string attributes:
        ``source_root`` (root directory containing the images, default
        ``'test_dir'``) and ``model_path`` (depth-model checkpoint, default
        ``'depth_anything_v2_vitl.pth'``).
    """
    cli = argparse.ArgumentParser(
        description='Generate depth and normal maps from images',
    )
    # (flag, default, help) for every string-valued option.
    option_specs = (
        ('--source_root', 'test_dir', 'Root directory containing the images'),
        ('--model_path', 'depth_anything_v2_vitl.pth',
         'Path to the depth model checkpoint'),
    )
    for flag, fallback, blurb in option_specs:
        cli.add_argument(flag, type=str, default=fallback, help=blurb)
    namespace = cli.parse_args()
    return namespace


def generate_depth_maps(source_root, model_path):
    """Run DepthAnythingV2 (ViT-L) over every image in ``<source_root>/origin``.

    Each file is read with ``cv2.imread`` and passed to ``model.infer_image``.
    The resulting depth map is min-max rescaled to 0-255, cast to ``uint8``
    and saved as ``<source_root>/depth/<image stem>.npy``. Files that cannot
    be decoded as images, or that fail during processing, are reported and
    skipped, so one bad file does not abort the run.

    Args:
        source_root: Directory (str or Path) containing an ``origin``
            subfolder of input images.
        model_path: Path to the ``DepthAnythingV2`` state-dict checkpoint.

    Returns:
        Path of the ``depth`` output directory.
    """
    source_root = Path(source_root)
    origin = source_root / 'origin'
    depth_path = source_root / 'depth'
    depth_path.mkdir(parents=True, exist_ok=True)

    # ViT-L configuration; these hyper-parameters must match the checkpoint.
    model = DepthAnythingV2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).cuda()
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    # Only regular files are candidates. Sub-directories would otherwise be
    # handed to cv2.imread. Sorting gives a deterministic processing order.
    image_paths = sorted(p for p in origin.glob('*') if p.is_file())

    with torch.inference_mode():
        for image_path in tqdm(image_paths):
            # cv2.imread signals failure by returning None, not by raising.
            raw_img = cv2.imread(str(image_path))
            if raw_img is None:
                print(f'Skipping {image_path}: not a readable image')
                continue

            try:
                depth = model.infer_image(raw_img)  # HxW raw depth map

                d_min = depth.min()
                d_range = depth.max() - d_min
                if d_range > 0:
                    depth = (depth - d_min) / d_range * 255.0
                else:
                    # Constant map: 0/0 would produce NaN, whose uint8 cast is undefined.
                    depth = np.zeros_like(depth)
                depth = depth.astype(np.uint8)

                np.save(depth_path / f'{image_path.stem}.npy', depth)
            except Exception as e:  # best-effort: report the failing file and move on
                print(f'Failed to process {image_path}: {e}')

    return depth_path


def calculate_normal_map(img_path: Path, ksize=5):
    """Derive a surface-normal image from a 2-D depth map saved with ``np.save``.

    The depth array is loaded and cast to float32. Its horizontal and vertical
    gradients are computed with ``cv2.Sobel`` (aperture ``ksize``). Each
    pixel's vector ``(dx, dy, -1)`` is scaled to unit length. The components,
    which lie in [-1, 1], are then mapped linearly onto [0, 255] and truncated
    to ``uint8``.

    Args:
        img_path: Path to the ``.npy`` depth file.
        ksize: Sobel aperture size passed to ``cv2.Sobel``.

    Returns:
        A ``uint8`` array of shape (3, H, W) whose channels are the x, y and
        z normal components.
    """
    depth_map = np.load(img_path).astype(np.float32)

    # First derivative along x (1, 0), then along y (0, 1).
    grad_x, grad_y = (
        cv2.Sobel(depth_map, cv2.CV_32F, order_x, order_y, ksize=ksize)
        for order_x, order_y in ((1, 0), (0, 1))
    )
    # The z component is fixed at -1 for every pixel.
    grad_z = np.full_like(grad_x, -1)

    vectors = np.stack([grad_x, grad_y, grad_z], axis=-1)  # (H, W, 3)
    lengths = np.linalg.norm(vectors, axis=-1, keepdims=True)
    unit = vectors / (lengths + 1e-6)  # epsilon guards against division by zero

    # Map [-1, 1] -> [0, 255]; astype truncates the fractional part.
    encoded = ((unit + 1) / 2 * 255).astype("uint8")
    return np.moveaxis(encoded, -1, 0)  # (H, W, 3) -> (3, H, W)


def generate_normal_maps(source_root, ksize=5):
    """Convert every depth map in ``<source_root>/depth`` into a normal map.

    Each ``*.npy`` file there is passed to ``calculate_normal_map``. The
    result is saved under the same file name in ``<source_root>/normal``,
    which is created if missing.

    Args:
        source_root: Directory (str or Path) containing the ``depth`` folder.
        ksize: Sobel aperture size forwarded to ``calculate_normal_map``.

    Returns:
        Path of the ``normal`` output directory.
    """
    source_root = Path(source_root)
    depth_root = source_root / 'depth'
    normal_root = source_root / 'normal'
    normal_root.mkdir(parents=True, exist_ok=True)

    # Sorted so files are processed in a deterministic order across runs.
    for depth_img_path in tqdm(sorted(depth_root.glob('*.npy'))):
        normal_map = calculate_normal_map(depth_img_path, ksize=ksize)
        # Reuse the depth file's name (stem + '.npy') so depth/normal pairs line up.
        np.save(normal_root / depth_img_path.name, normal_map)

    return normal_root


def main():
    """Command-line entry point: build depth maps, then normal maps from them."""
    args = parse_args()

    print(f"Generating depth maps from images in {args.source_root}")
    generate_depth_maps(args.source_root, args.model_path)

    # generate_normal_maps reads <source_root>/depth, the directory just written.
    print("Generating normal maps from depth maps")
    generate_normal_maps(args.source_root)

    print("Processing complete!")


if __name__ == "__main__":
    main()