diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..9a717d468647f839ec11347da18ca2611b2ea889 --- /dev/null +++ b/app.py @@ -0,0 +1,124 @@ +import gradio as gr +from gradio_image_prompter import ImagePrompter +from detectron2.config import LazyConfig, instantiate +from detectron2.checkpoint import DetectionCheckpointer +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download + +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +model_choice = { + 'SAM': None, + 'HQ-SAM': None, + 'SAM2': None +} + +for model_type in model_choice.keys(): + model_choice[model_type] = hf_hub_download(repo_id="XiaRho/SEMat", filename=f"SEMat_{model_type}.pth", repo_type="model") + +def load_model(model_type='SAM2'): + assert model_type in model_choice.keys() + config_path = './configs/SEMat_{}.py'.format(model_type) + cfg = LazyConfig.load(config_path) + + if hasattr(cfg.model.sam_model, 'ckpt_path'): + cfg.model.sam_model.ckpt_path = None + else: + cfg.model.sam_model.checkpoint = None + model = instantiate(cfg.model) + if model.lora_rank is not None: + model.init_lora() + model.to(DEVICE) + DetectionCheckpointer(model).load(model_choice[model_type]) + model.eval() + return model, model_type + +def transform_image_bbox(prompts): + if len(prompts["points"]) != 1: + raise gr.Error("Please input only one BBox.", duration=5) + [[x1, y1, idx_3, x2, y2, idx_6]] = prompts["points"] + if idx_3 != 2 or idx_6 != 3: + raise gr.Error("Please input BBox instead of point.", duration=5) + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + + img = prompts["image"] + ori_H, ori_W, _ = img.shape + + scale = 1024 * 1.0 / max(ori_H, ori_W) + new_H, new_W = ori_H * scale, ori_W * scale + new_W = int(new_W + 0.5) + new_H = int(new_H + 0.5) + + img = cv2.resize(img, (new_W, new_H), interpolation=cv2.INTER_LINEAR) + padding = np.zeros([1024, 1024, 3], dtype=img.dtype) + padding[: new_H, : new_W, :] = img + img = padding + # img = img[:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) / 255.0 + img = img.transpose((2, 0, 1)).astype(np.float32) / 255.0 + + [[x1, y1, _, x2, y2, _]] = prompts["points"] + x1, y1, x2, y2 = int(x1 * scale + 0.5), int(y1 * scale + 0.5), int(x2 * scale + 0.5), int(y2 * scale + 0.5) + bbox = np.clip(np.array([[x1, y1, x2, y2]]) * 1.0, 0, 1023.0) + + return img, bbox, (ori_H, ori_W), (new_H, new_W) + +if __name__ == '__main__': + + model, model_type = load_model() + + def inference_image(prompts, input_model_type): + + global model_type + global model + + if input_model_type != model_type: + gr.Info('Loading SEMat of {} version.'.format(input_model_type), duration=5) + _model, _ = load_model(input_model_type) + model_type = input_model_type + model = _model + + image, bbox, ori_H_W, pad_H_W = transform_image_bbox(prompts) + input_data = { + 'image': torch.from_numpy(image)[None].to(model.device), + 'bbox': torch.from_numpy(bbox)[None].to(model.device), + } + + with torch.no_grad(): + inputs = model.preprocess_inputs(input_data) + images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition'] + + if model.backbone_condition: + condition_proj = model.condition_embedding(condition) + elif model.backbone_bbox_prompt is not None or model.bbox_prompt_all_block is not None: + condition_proj = bbox + else: + condition_proj = None + + low_res_masks, pred_alphas, pred_trimap, sam_hq_matting_token = model.forward_samhq_and_matting_decoder(images, bbox, condition_proj) + + + output_alpha = np.uint8(pred_alphas[0, 0][:pad_H_W[0], :pad_H_W[1], None].repeat(1, 1, 3).cpu().numpy() * 255) + + return output_alpha + + with gr.Blocks() as demo: + + with gr.Row(): + with gr.Column(scale=45): + img_in = ImagePrompter(type='numpy', show_label=False, label="query image") + + with gr.Column(scale=45): + img_out = gr.Image(type='pil', label="output") + + with gr.Row(): + with gr.Column(scale=45): + input_model_type = gr.Dropdown(list(model_choice.keys()), value='SAM2', label="Trained SEMat Version") + + with gr.Column(scale=45): + bt = gr.Button() + + bt.click(inference_image, inputs=[img_in, input_model_type], outputs=[img_out]) + +demo.launch() + diff --git a/configs/SEMat_HQ-SAM.py b/configs/SEMat_HQ-SAM.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4a90d280dedac697f958cad1985dc14c4e73e7 --- /dev/null +++ b/configs/SEMat_HQ-SAM.py @@ -0,0 +1,48 @@ +from .common.train import train +from .semantic_enhanced_matting.model import model +from .common.optimizer import optimizer +from .common.scheduler import lr_multiplier +from .semantic_enhanced_matting.dataloader import dataloader +from modeling.decoder.unet_detail_capture import MattingDetailDecoder +from detectron2.config import LazyCall as L + +model.sam_model.model_type = 'vit_l' +model.sam_model.checkpoint = None +model.vis_period = 200 +model.output_dir = '?' + +train.max_iter = 60000 +train.eval_period = int(train.max_iter * 1 / 10) +train.checkpointer.period = int(train.max_iter * 1 / 10) +train.checkpointer.max_to_keep = 1 + +optimizer.lr = 5e-5 + +lr_multiplier.scheduler.values = [1.0, 0.5, 0.2] +lr_multiplier.scheduler.milestones = [0.5, 0.75] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 250 / train.max_iter + +train.output_dir = './work_dirs/SEMat_HQ-SAM' + +model.lora_rank = 16 +model.lora_alpha = 16 +model.matting_decoder = L(MattingDetailDecoder)( + vit_intern_feat_in = 1024, + vit_intern_feat_index = [0, 1, 2, 3], + norm_type = 'SyncBN', + block_num = 2, + img_feat_in = 6, + norm_mask_logits = 6.5 +) +model.backbone_bbox_prompt = 'bbox' +model.backbone_bbox_prompt_loc = [2, 3] +model.backbone_bbox_prompt_loss_weight = 1.0 +model.matting_token = True +model.sam_model.matting_token = 3 +model.sam_model.frozen_decoder = True +model.sam_hq_token_reg = 0.2 +model.reg_w_bce_loss = True +model.matting_token_sup = 'trimap' +model.matting_token_sup_loss_weight = 0.05 +model.trimap_loss_type = 'NGHM' diff --git a/configs/SEMat_SAM.py b/configs/SEMat_SAM.py new file mode 100644 index 0000000000000000000000000000000000000000..c32de0be7ef2dee2b3c7f549039d1612c55309a5 --- /dev/null +++ b/configs/SEMat_SAM.py @@ -0,0 +1,51 @@ +from .common.train import train +from .semantic_enhanced_matting.model import model +from .common.optimizer import optimizer +from .common.scheduler import lr_multiplier +from .semantic_enhanced_matting.dataloader import dataloader +from modeling.decoder.unet_detail_capture import MattingDetailDecoder +from detectron2.config import LazyCall as L + +model.sam_model.model_type = 'vit_l' +model.sam_model.checkpoint = None +model.vis_period = 200 +model.output_dir = '?' + +train.max_iter = 60000 +train.eval_period = int(train.max_iter * 1 / 10) +train.checkpointer.period = int(train.max_iter * 1 / 10) +train.checkpointer.max_to_keep = 1 + +optimizer.lr = 5e-5 + +lr_multiplier.scheduler.values = [1.0, 0.5, 0.2] +lr_multiplier.scheduler.milestones = [0.5, 0.75] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 250 / train.max_iter + +train.output_dir = './work_dirs/SEMat_SAM' + +model.lora_rank = 16 +model.lora_alpha = 16 +model.matting_decoder = L(MattingDetailDecoder)( + vit_intern_feat_in = 1024, + vit_intern_feat_index = [0, 1, 2, 3], + norm_type = 'SyncBN', + block_num = 2, + img_feat_in = 6, + norm_mask_logits = 6.5 +) +model.backbone_bbox_prompt = 'bbox' +model.backbone_bbox_prompt_loc = [2, 3] +model.backbone_bbox_prompt_loss_weight = 1.0 +model.matting_token = True +model.sam_model.matting_token = 3 +model.sam_model.frozen_decoder = True +model.sam_hq_token_reg = 0.2 +model.reg_on_sam_logits = True +model.reg_w_bce_loss = True +model.matting_token_sup = 'trimap' +model.matting_token_sup_loss_weight = 0.05 +model.trimap_loss_type = 'NGHM' +model.sam_model.wo_hq = True +model.sam_model.mask_matting_res_add = False diff --git a/configs/SEMat_SAM2.py b/configs/SEMat_SAM2.py new file mode 100644 index 0000000000000000000000000000000000000000..c6103a578349990e70cffce818464b9d6889dbea --- /dev/null +++ b/configs/SEMat_SAM2.py @@ -0,0 +1,57 @@ +from .common.train import train +from .semantic_enhanced_matting.model import model +from .common.optimizer import optimizer +from .common.scheduler import lr_multiplier +from .semantic_enhanced_matting.dataloader import dataloader +from modeling.decoder.unet_detail_capture import MattingDetailDecoder +from detectron2.config import LazyCall as L +from sam2.build_sam import build_sam2 + +model.sam_model.model_type = 'vit_l' +model.sam_model.checkpoint = None +model.vis_period = 200 +model.output_dir = '?' + +train.max_iter = 60000 +train.eval_period = int(train.max_iter * 1 / 10) +train.checkpointer.period = int(train.max_iter * 1 / 10) +train.checkpointer.max_to_keep = 1 + +optimizer.lr = 5e-5 + +lr_multiplier.scheduler.values = [1.0, 0.5, 0.2] +lr_multiplier.scheduler.milestones = [0.5, 0.75] +lr_multiplier.scheduler.num_updates = train.max_iter +lr_multiplier.warmup_length = 250 / train.max_iter + +train.output_dir = './work_dirs/SEMat_SAM2' + +model.sam2 = True +model.sam_model = L(build_sam2)( + config_file = 'sam2_hiera_l.yaml', + ckpt_path = None, + device = "cuda", + bbox_mask_matting_token = True, + mode="train", + upscaled_embedding_res_add = False +) +model.lora_rank = 16 +model.lora_alpha = 16 +model.matting_decoder = L(MattingDetailDecoder)( + vit_intern_feat_in = 1024, + vit_intern_feat_index = [0, 1, 2, 3], + norm_type = 'SyncBN', + block_num = 2, + img_feat_in = 6, + norm_mask_logits = 6.5, + sam2_multi_scale_feates = True +) +model.backbone_bbox_prompt = 'bbox' +model.backbone_bbox_prompt_loc = [2, 3] +model.backbone_bbox_prompt_loss_weight = 1.0 +model.matting_token = True +model.sam_hq_token_reg = 0.2 +model.reg_w_bce_loss = True +model.matting_token_sup = 'trimap' +model.matting_token_sup_loss_weight = 0.05 +model.trimap_loss_type = 'NGHM' diff --git a/configs/common/optimizer.py b/configs/common/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..42c4d5a5ca093fce2b7fc2578cac76ce2c7944ba --- /dev/null +++ b/configs/common/optimizer.py @@ -0,0 +1,26 @@ +from detectron2 import model_zoo +from functools import partial + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone"): + if ".pos_embed" in name or ".patch_embed" in name: + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + return lr_decay_rate ** (num_layers + 1 - layer_id) + +# Optimizer +optimizer = model_zoo.get_config("common/optim.py").AdamW +optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.65) +optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} \ No newline at end of file diff --git a/configs/common/scheduler.py b/configs/common/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..0ceaf49576628f63b62c76fa46c3ce46c5ced308 --- /dev/null +++ b/configs/common/scheduler.py @@ -0,0 +1,13 @@ +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler +from fvcore.common.param_scheduler import MultiStepParamScheduler + +lr_multiplier = L(WarmupParamScheduler)( + scheduler=L(MultiStepParamScheduler)( + values=[1.0, 0.1, 0.01], + milestones=[96778, 103579], + num_updates=100, + ), + warmup_length=250 / 100, + warmup_factor=0.001, +) \ No newline at end of file diff --git a/configs/common/train.py b/configs/common/train.py new file mode 100644 index 0000000000000000000000000000000000000000..441786e4f0923e05bcc6be1ee992497836bb96d6 --- /dev/null +++ b/configs/common/train.py @@ -0,0 +1,17 @@ +train = dict( + output_dir="./output", + init_checkpoint="", + max_iter=90000, + amp=dict(enabled=False), # options for Automatic Mixed Precision + ddp=dict( # options for DistributedDataParallel + broadcast_buffers=True, + find_unused_parameters=False, + fp16_compression=True, + ), + checkpointer=dict(period=5000, max_to_keep=100), # options for PeriodicCheckpointer + eval_period=5000, + log_period=20, + device="cuda", + seed=42 + # ... +) \ No newline at end of file diff --git a/configs/semantic_enhanced_matting/dataloader.py b/configs/semantic_enhanced_matting/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..ba4247445876995001c7589c0616b79f63cb843f --- /dev/null +++ b/configs/semantic_enhanced_matting/dataloader.py @@ -0,0 +1,62 @@ +from omegaconf import OmegaConf +from torch.utils.data import ConcatDataset +from detectron2.config import LazyCall as L + +from data.dim_dataset import build_d2_test_dataloader, AdobeCompositionEvaluator, adobe_composition_collate_fn, RW100Test, AIM500Test, AM2KTest, P3M500Test, RWP636Test, SIMTest + +AIM500_PATH = '/path/to/datasets/AIM-500' +RW100_PATH = '/path/to/datasets/RefMatte_RW_100' +AM2K_PATH = '/path/to/datasets/AM-2K' +P3M500_PATH = '/path/to/datasets/P3M-10k/validation/P3M-500-NP' +RWP636_PATH = '/path/to/datasets/RealWorldPortrait-636' +SIM_PATH = '/path/to/datasets/SIMD/generated_testset' + +dataloader = OmegaConf.create() +test_dataset = L(ConcatDataset)( + datasets = [ + L(AIM500Test)( + data_dir = AIM500_PATH, + target_size = 1024, + multi_fg = True, + ), + L(RW100Test)( + data_dir = RW100_PATH, + target_size = 1024, + multi_fg = True, + ), + L(AM2KTest)( + data_dir = AM2K_PATH, + target_size = 1024, + multi_fg = True, + ), + L(P3M500Test)( + data_dir = P3M500_PATH, + target_size = 1024, + multi_fg = True, + ), + L(RWP636Test)( + data_dir = RWP636_PATH, + target_size = 1024, + multi_fg = True + ), + L(SIMTest)( + data_dir = SIM_PATH, + target_size = 1024, + multi_fg = True + ) + ] +) + +dataloader.test = L(build_d2_test_dataloader)( + dataset = test_dataset, + local_batch_size = 1, + num_workers = 4, + collate_fn = adobe_composition_collate_fn +) + +dataloader.evaluator = L(AdobeCompositionEvaluator)( + save_eval_results_step = 10, + output_dir = None, # modify in EvalHook (do_test) + eval_dataset_type = ['RW100', 'AIM500', 'AM2K', 'P3M500', 'RWP636', 'SIM'], + distributed = True, +), diff --git a/configs/semantic_enhanced_matting/model.py b/configs/semantic_enhanced_matting/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ab3f69d3570585f46756ed1f213374e3073d1c --- /dev/null +++ b/configs/semantic_enhanced_matting/model.py @@ -0,0 +1,35 @@ +from detectron2.config import LazyCall as L + +from modeling import Detail_Capture, MattingCriterion +from modeling.meta_arch import SamHqMatte +from modeling.semantic_enhanced_matting.build_sam import sam_model_registry_def +# from modeling.sam_hq_matting.predictor import SamPredictor +from modeling.semantic_enhanced_matting import MaskDecoderMatting + +mask_token_only = False + +model = L(SamHqMatte)( + + # original sam_hq + sam_model = L(sam_model_registry_def)( + model_type = 'vit_b', + checkpoint = None, + ), + hq_token_only = True, + hq_features_type = 'Final', + multimask_output = True, + + # loss function + criterion=L(MattingCriterion)( + losses = ['unknown_l1_loss', 'known_l1_loss', 'loss_pha_laplacian', 'loss_gradient_penalty'] + ), + + # other params. + pixel_mean = [123.675 / 255., 116.280 / 255., 103.530 / 255.], + pixel_std = [58.395 / 255., 57.120 / 255., 57.375 / 255.], + + lora_rank = None, + lora_alpha = None, + w_dora = False, + w_rslora = False, +) \ No newline at end of file diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6d8b9c2967e4e7660868a4b5f2d1bf4e2aaa19 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1 @@ +from .dim_dataset import * \ No newline at end of file diff --git a/data/__pycache__/__init__.cpython-38.pyc b/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03cb664bb8b36d0b6dd6bc608eab66e303d7b450 Binary files /dev/null and b/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/data/__pycache__/dim_dataset.cpython-38.pyc b/data/__pycache__/dim_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d7ebc3178ffa57ef88cc28c3beb8fa108bc51bb Binary files /dev/null and b/data/__pycache__/dim_dataset.cpython-38.pyc differ diff --git a/data/__pycache__/evaluate.cpython-38.pyc b/data/__pycache__/evaluate.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..290aebf059a454b8e79c21878c16ab7ed26bdba4 Binary files /dev/null and b/data/__pycache__/evaluate.cpython-38.pyc differ diff --git a/data/__pycache__/rand_augment.cpython-38.pyc b/data/__pycache__/rand_augment.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..254ab37896c285f4cedfa083a2b823ce6e726761 Binary files /dev/null and b/data/__pycache__/rand_augment.cpython-38.pyc differ diff --git a/data/coconut_dataset.py b/data/coconut_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..494f70266577fc1dbe063aa126af22c108c29292 --- /dev/null +++ b/data/coconut_dataset.py @@ -0,0 +1,377 @@ +import os +import time +import json +import torch +import numpy as np +import cv2 +from torch.utils.data import Dataset, DistributedSampler, Sampler +from torchvision import transforms +from detectron2.utils.logger import setup_logger +from typing import Optional +from operator import itemgetter +from collections import defaultdict + +from data.dim_dataset import GenBBox + + +def random_interp(): + return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) + + +class SplitConcatImage(object): + + def __init__(self, concat_num=4, wo_mask_to_mattes=False): + self.concat_num = concat_num + self.wo_mask_to_mattes = wo_mask_to_mattes + if self.wo_mask_to_mattes: + assert self.concat_num == 5 + + def __call__(self, concat_image): + if isinstance(concat_image, list): + concat_image, image_path = concat_image[0], concat_image[1] + else: + image_path = None + H, W, _ = concat_image.shape + + concat_num = self.concat_num + if image_path is not None: + if '06-14' in image_path: + concat_num = 4 + elif 'ori_mask' in image_path or 'SEMat' in image_path: + concat_num = 3 + else: + concat_num = 5 + + assert W % concat_num == 0 + W = W // concat_num + + image = concat_image[:H, :W] + if self.concat_num != 3: + trimap = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W] + if self.wo_mask_to_mattes: + alpha = concat_image[:H, 2 * W: 3 * W] + else: + alpha = concat_image[:H, (concat_num - 1) * W: concat_num * W] + else: + trimap = concat_image[:H, (concat_num - 1) * W: concat_num * W] + alpha = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W] + + return {'image': image, 'trimap': trimap, 'alpha': alpha} + + +class RandomHorizontalFlip(object): + + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, sample): + if np.random.uniform(0, 1) < self.prob: + for key in sample.keys(): + sample[key] = cv2.flip(sample[key], 1) + return sample + +class EmptyAug(object): + def __call__(self, sample): + return sample + +class RandomReszieCrop(object): + + def __init__(self, output_size=1024, aug_scale_min=0.5, aug_scale_max=1.5): + self.desired_size = output_size + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + + def __call__(self, sample): + H, W, _ = sample['image'].shape + + if self.aug_scale_min == 1.0 and self.aug_scale_max == 1.0: + crop_H, crop_W = H, W + crop_y1, crop_y2 = 0, crop_H + crop_x1, crop_x2 = 0, crop_W + scale_W, scaled_H = W, H + elif self.aug_scale_min == -1.0 and self.aug_scale_max == -1.0: + scale = min(self.desired_size / H, self.desired_size / W) + scaled_H, scale_W = round(H * scale), round(W * scale) + crop_H, crop_W = scaled_H, scale_W + crop_y1, crop_y2 = 0, crop_H + crop_x1, crop_x2 = 0, crop_W + else: + # random size + random_scale = np.random.uniform(0, 1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min # random_val: 0.5 ~ 1.5 + scaled_size = round(random_scale * self.desired_size) + + scale = min(scaled_size / H, scaled_size / W) + scaled_H, scale_W = round(H * scale), round(W * scale) + + # random crop + crop_H, crop_W = min(self.desired_size, scaled_H), min(self.desired_size, scale_W) # crop_size + margin_H, margin_W = max(scaled_H - crop_H, 0), max(scale_W - crop_W, 0) + offset_H, offset_W = np.random.randint(0, margin_H + 1), np.random.randint(0, margin_W + 1) + crop_y1, crop_y2 = offset_H, offset_H + crop_H + crop_x1, crop_x2 = offset_W, offset_W + crop_W + + for key in sample.keys(): + sample[key] = cv2.resize(sample[key], (scale_W, scaled_H), interpolation=random_interp())[crop_y1: crop_y2, crop_x1: crop_x2, :] # resize and crop + padding = np.zeros(shape=(self.desired_size, self.desired_size, 3), dtype=sample[key].dtype) # pad to desired_size + padding[: crop_H, : crop_W, :] = sample[key] + sample[key] = padding + + return sample + + +class RandomJitter(object): + """ + Random change the hue of the image + """ + + def __call__(self, sample): + + image = sample['image'] + + # convert to HSV space, convert to float32 image to keep precision during space conversion. + image = cv2.cvtColor(image.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV) + # Hue noise + hue_jitter = np.random.randint(-40, 40) + image[:, :, 0] = np.remainder(image[:, :, 0].astype(np.float32) + hue_jitter, 360) + # Saturation noise + sat_bar = image[:, :, 1].mean() + + sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10 + sat = image[:, :, 1] + sat = np.abs(sat + sat_jitter) + sat[sat>1] = 2 - sat[sat>1] + image[:, :, 1] = sat + # Value noise + val_bar = image[:, :, 2].mean() + + val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10 + val = image[:, :, 2] + val = np.abs(val + val_jitter) + val[val>1] = 2 - val[val>1] + image[:, :, 2] = val + # convert back to BGR space + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + sample['image'] = image * 255 + + return sample + + +class ToTensor(object): + + def __call__(self, sample): + image, alpha, trimap = sample['image'][:, :, ::-1], sample['alpha'], sample['trimap'] + + # image + image = image.transpose((2, 0, 1)) / 255. + sample['image'] = torch.from_numpy(image).float() + + # alpha + alpha = alpha.transpose((2, 0, 1))[0: 1] / 255. + alpha[alpha < 0 ] = 0 + alpha[alpha > 1] = 1 + sample['alpha'] = torch.from_numpy(alpha).float() + + # trimap + trimap = trimap.transpose((2, 0, 1))[0: 1] / 1. + sample['trimap'] = torch.from_numpy(trimap).float() + sample['trimap'][sample['trimap'] < 85] = 0 + sample['trimap'][sample['trimap'] >= 170] = 1 + sample['trimap'][sample['trimap'] >= 85] = 0.5 + + return sample + + +class COCONutData(Dataset): + def __init__( + self, + json_path, + data_root_path, + output_size = 512, + aug_scale_min = 0.5, + aug_scale_max = 1.5, + with_bbox = False, + bbox_offset_factor = None, + phase = "train", + min_miou = 95, + miou_json = '', + remove_coco_transparent = False, + coconut_num_ratio = None, + return_multi_fg_info = False, + wo_accessory_fusion = False, + wo_mask_to_mattes = False, + return_image_name = False, + ): + + self.data_root_path = data_root_path + self.output_size = output_size + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + self.with_bbox = with_bbox + self.bbox_offset_factor = bbox_offset_factor + self.phase = phase + self.min_miou = min_miou + self.miou_json = miou_json + self.remove_coco_transparent = remove_coco_transparent + self.coconut_num_ratio = coconut_num_ratio + self.return_multi_fg_info = return_multi_fg_info + self.wo_accessory_fusion = wo_accessory_fusion # TODO + self.wo_mask_to_mattes = wo_mask_to_mattes + self.return_image_name = return_image_name + assert self.wo_accessory_fusion + self.wo_mask_to_mattes <= 1 + assert self.phase == 'train' + + self.data_path = [] + with open(json_path, "r") as file: + coconut_matting_info = json.load(file) + + if self.miou_json != '': + name_2_miou_dict = defaultdict(int) + with open(self.miou_json, "r") as file: + coconut_matting_miou = json.load(file) + for miou, name in coconut_matting_miou: + name_2_miou_dict[name] = miou + for i in coconut_matting_info: + if 'accessory' in i['save_path']: + self.data_path.append(i['save_path']) + elif name_2_miou_dict[i['save_path'].split('/')[-1]] >= self.min_miou: + if not (self.remove_coco_transparent and 'glass' in i['save_path']): + self.data_path.append(i['save_path']) + else: + for i in coconut_matting_info: + self.data_path.append(i['save_path']) + + if 'accessory' in json_path: + concat_num = 5 + elif 'ori_mask' in json_path: + concat_num = 3 + else: + concat_num = 4 + + train_trans = [ + SplitConcatImage(concat_num, wo_mask_to_mattes = self.wo_mask_to_mattes), + RandomHorizontalFlip(prob=0 if hasattr(self, 'return_image_name') and self.return_image_name else 0.5), + RandomReszieCrop(self.output_size, self.aug_scale_min, self.aug_scale_max), + EmptyAug() if hasattr(self, 'return_image_name') and self.return_image_name else RandomJitter(), + ToTensor(), + GenBBox(bbox_offset_factor=self.bbox_offset_factor) + ] + self.transform = transforms.Compose(train_trans) + print('coconut num: ', len(self.data_path) * self.coconut_num_ratio if self.coconut_num_ratio is not None else len(self.data_path)) + + def __getitem__(self, idx): + if self.coconut_num_ratio is not None: + if self.coconut_num_ratio < 1.0 or idx >= len(self.data_path): + idx = np.random.randint(0, len(self.data_path)) + concat_image = cv2.imread(os.path.join(self.data_root_path, self.data_path[idx])) + sample = self.transform([concat_image, self.data_path[idx]]) + sample['dataset_name'] = 'COCONut' + if self.return_multi_fg_info: + sample['multi_fg'] = False + if hasattr(self, 'return_image_name') and self.return_image_name: + sample['image_name'] = self.data_path[idx] + return sample + + def __len__(self): + if self.coconut_num_ratio is not None: + return int(len(self.data_path) * self.coconut_num_ratio) + else: + return len(self.data_path) + + +class DatasetFromSampler(Dataset): + """Dataset to create indexes from `Sampler`. + + Args: + sampler: PyTorch sampler + """ + + def __init__(self, sampler: Sampler): + """Initialisation for DatasetFromSampler.""" + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + """Gets element of the dataset. + + Args: + index: index of the element in the dataset + + Returns: + Single element by index + """ + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + """ + Returns: + int: length of the dataset + """ + return len(self.sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """ + Wrapper over `Sampler` for distributed training. + Allows you to use any sampler in distributed mode. + It is especially useful in conjunction with + `torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSamplerWrapper instance as a DataLoader + sampler, and load a subset of subsampled data of the original dataset + that is exclusive to it. + .. note:: + Sampler is assumed to be of constant size. + """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + """ + Args: + sampler: Sampler used for subsampling + num_replicas (int, optional): Number of processes participating in + distributed training + rank (int, optional): Rank of the current process + within ``num_replicas`` + shuffle (bool, optional): If true (default), + sampler will shuffle the indices + """ + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self): + """@TODO: Docs. Contribution is welcome.""" + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +if __name__ == '__main__': + + + + dataset = COCONutData( + json_path = '/root/data/my_path/Matting/DiffMatte-main/24-06-14_coco-nut_matting.json', + data_root_path = '/root/data/my_path/Matting/DiffMatte-main', + output_size = 1024, + aug_scale_min = 0.5, + aug_scale_max = 1.5, + with_bbox = True, + bbox_offset_factor = 0.1, + phase = "train" + ) + data = dataset[0] + + for key, val in data.items(): + print(key, val.shape, torch.min(val), torch.max(val)) \ No newline at end of file diff --git a/data/dim_dataset.py b/data/dim_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f47549bb6e253201616934c158c6037d5893a01 --- /dev/null +++ b/data/dim_dataset.py @@ -0,0 +1,1476 @@ +''' +Dataloader to process Adobe Image Matting Dataset. + +From GCA_Matting(https://github.com/Yaoyi-Li/GCA-Matting/tree/master/dataloader) +''' +import os +import glob +import logging +import os.path as osp +import functools +import numpy as np +import torch +import cv2 +import math +import numbers +import random +import pickle +from torch.utils.data import Dataset, DataLoader +from torch.nn import functional as F +from torchvision import transforms +from easydict import EasyDict +from detectron2.utils.logger import setup_logger +from detectron2.utils import comm +from detectron2.data import build_detection_test_loader +import torchvision.transforms.functional + +import json +from PIL import Image +from detectron2.evaluation.evaluator import DatasetEvaluator +from collections import defaultdict + +from data.evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss, compute_gradient_loss, compute_connectivity_error + +# Base default config +CONFIG = EasyDict({}) + +# Model config +CONFIG.model = EasyDict({}) +# one-hot or class, choice: [3, 1] +CONFIG.model.trimap_channel = 1 + +# Dataloader config +CONFIG.data = EasyDict({}) +# feed forward image size (untested) +CONFIG.data.crop_size = 512 +# composition of two foregrounds, affine transform, crop and HSV jitter +CONFIG.data.cutmask_prob = 0.25 +CONFIG.data.augmentation = True +CONFIG.data.random_interp = True + +class Prefetcher(): + """ + Modified from the data_prefetcher in https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py + """ + def __init__(self, loader): + self.orig_loader = loader + self.stream = torch.cuda.Stream() + self.next_sample = None + + def preload(self): + try: + self.next_sample = next(self.loader) + except StopIteration: + self.next_sample = None + return + + with torch.cuda.stream(self.stream): + for key, value in self.next_sample.items(): + if isinstance(value, torch.Tensor): + self.next_sample[key] = value.cuda(non_blocking=True) + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + sample = self.next_sample + if sample is not None: + for key, value in sample.items(): + if isinstance(value, torch.Tensor): + sample[key].record_stream(torch.cuda.current_stream()) + self.preload() + else: + # throw stop exception if there is no more data to perform as a default dataloader + raise StopIteration("No samples in loader. example: `iterator = iter(Prefetcher(loader)); " + "data = next(iterator)`") + return sample + + def __iter__(self): + self.loader = iter(self.orig_loader) + self.preload() + return self + + +class ImageFile(object): + def __init__(self, phase='train'): + self.phase = phase + self.rng = np.random.RandomState(0) + + def _get_valid_names(self, *dirs, shuffle=True): + name_sets = [self._get_name_set(d) for d in dirs] + + def _join_and(a, b): + return a & b + + valid_names = list(functools.reduce(_join_and, name_sets)) + if shuffle: + self.rng.shuffle(valid_names) + + return valid_names + + @staticmethod + def _get_name_set(dir_name): + path_list = glob.glob(os.path.join(dir_name, '*')) + name_set = set() + for path in path_list: + name = os.path.basename(path) + name = os.path.splitext(name)[0] + name_set.add(name) + return name_set + + @staticmethod + def _list_abspath(data_dir, ext, data_list): + return [os.path.join(data_dir, name + ext) + for name in data_list] + +class ImageFileTrain(ImageFile): + def __init__( + self, + alpha_dir="train_alpha", + fg_dir="train_fg", + bg_dir="train_bg", + alpha_ext=".jpg", + fg_ext=".jpg", + bg_ext=".jpg", + fg_have_bg_num=None, + alpha_ratio_json = None, + alpha_min_ratio = None, + key_sample_ratio = None, + ): + super(ImageFileTrain, self).__init__(phase="train") + + self.alpha_dir = alpha_dir + self.fg_dir = fg_dir + self.bg_dir = bg_dir + self.alpha_ext = alpha_ext + self.fg_ext = fg_ext + self.bg_ext = bg_ext + logger = setup_logger(name=__name__) + + if not isinstance(self.alpha_dir, str): + assert len(self.alpha_dir) == len(self.fg_dir) == len(alpha_ext) == len(fg_ext) + self.valid_fg_list = [] + self.alpha = [] + self.fg = [] + self.key_alpha = [] + self.key_fg = [] + for i in range(len(self.alpha_dir)): + valid_fg_list = self._get_valid_names(self.fg_dir[i], self.alpha_dir[i]) + valid_fg_list.sort() + alpha = self._list_abspath(self.alpha_dir[i], self.alpha_ext[i], valid_fg_list) + fg = self._list_abspath(self.fg_dir[i], self.fg_ext[i], valid_fg_list) + self.valid_fg_list += valid_fg_list + + self.alpha += alpha * fg_have_bg_num[i] + self.fg += fg * fg_have_bg_num[i] + + if alpha_ratio_json[i] is not None: + tmp_key_alpha = [] + tmp_key_fg = [] + name_to_alpha_path = dict() + for name in alpha: + name_to_alpha_path[name.split('/')[-1].split('.')[0]] = name + name_to_fg_path = dict() + for name in fg: + name_to_fg_path[name.split('/')[-1].split('.')[0]] = name + + with open(alpha_ratio_json[i], 'r') as file: + alpha_ratio_list = json.load(file) + for ratio, name in alpha_ratio_list: + if ratio < alpha_min_ratio[i]: + break + tmp_key_alpha.append(name_to_alpha_path[name.split('.')[0]]) + tmp_key_fg.append(name_to_fg_path[name.split('.')[0]]) + + self.key_alpha.extend(tmp_key_alpha * fg_have_bg_num[i]) + self.key_fg.extend(tmp_key_fg * fg_have_bg_num[i]) + + if len(self.key_alpha) != 0 and key_sample_ratio > 0: + repeat_num = key_sample_ratio * (len(self.alpha) - len(self.key_alpha)) / len(self.key_alpha) / (1 - key_sample_ratio) - 1 + print('key sample num:', len(self.key_alpha), ', repeat num: ', repeat_num) + for i in range(math.ceil(repeat_num)): + self.alpha += self.key_alpha + self.fg += self.key_fg + + else: + self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir) + self.valid_fg_list.sort() + self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_fg_list) + self.fg = self._list_abspath(self.fg_dir, self.fg_ext, self.valid_fg_list) + + self.valid_bg_list = [os.path.splitext(name)[0] for name in os.listdir(self.bg_dir)] + self.valid_bg_list.sort() + + if fg_have_bg_num is not None: + # assert fg_have_bg_num * len(self.valid_fg_list) <= len(self.valid_bg_list) + # self.valid_bg_list = self.valid_bg_list[: fg_have_bg_num * len(self.valid_fg_list)] + assert len(self.alpha) <= len(self.valid_bg_list) + self.valid_bg_list = self.valid_bg_list[: len(self.alpha)] + + self.bg = self._list_abspath(self.bg_dir, self.bg_ext, self.valid_bg_list) + + def __len__(self): + return len(self.alpha) + +class ImageFileTest(ImageFile): + def __init__(self, + alpha_dir="test_alpha", + merged_dir="test_merged", + trimap_dir="test_trimap", + alpha_ext=".png", + merged_ext=".png", + trimap_ext=".png"): + super(ImageFileTest, self).__init__(phase="test") + + self.alpha_dir = alpha_dir + self.merged_dir = merged_dir + self.trimap_dir = trimap_dir + self.alpha_ext = alpha_ext + self.merged_ext = merged_ext + self.trimap_ext = trimap_ext + + self.valid_image_list = self._get_valid_names(self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False) + + self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_image_list) + self.merged = self._list_abspath(self.merged_dir, self.merged_ext, self.valid_image_list) + self.trimap = self._list_abspath(self.trimap_dir, self.trimap_ext, self.valid_image_list) + + def __len__(self): + return len(self.alpha) + +interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4] + + +def maybe_random_interp(cv2_interp): + if CONFIG.data.random_interp: + return np.random.choice(interp_list) + else: + return cv2_interp + + +class ToTensor(object): + """ + Convert ndarrays in sample to Tensors with normalization. + """ + def __init__(self, phase="test"): + self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1) + self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1) + self.phase = phase + + def __call__(self, sample): + image, alpha, trimap, mask = sample['image'][:,:,::-1], sample['alpha'], sample['trimap'], sample['mask'] + + alpha[alpha < 0 ] = 0 + alpha[alpha > 1] = 1 + + image = image.transpose((2, 0, 1)).astype(np.float32) + alpha = np.expand_dims(alpha.astype(np.float32), axis=0) + + mask = np.expand_dims(mask.astype(np.float32), axis=0) + + image /= 255. + + if self.phase == "train": + fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255. + sample['fg'] = torch.from_numpy(fg) + bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255. + sample['bg'] = torch.from_numpy(bg) + + sample['image'], sample['alpha'], sample['trimap'] = \ + torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long) + sample['image'] = sample['image'] + + if CONFIG.model.trimap_channel == 3: + sample['trimap'] = F.one_hot(sample['trimap'], num_classes=3).permute(2,0,1).float() + elif CONFIG.model.trimap_channel == 1: + sample['trimap'] = sample['trimap'][None,...].float() + else: + raise NotImplementedError("CONFIG.model.trimap_channel can only be 3 or 1") + sample['trimap'][sample['trimap'] < 85] = 0 + sample['trimap'][sample['trimap'] >= 170] = 1 + sample['trimap'][sample['trimap'] >= 85] = 0.5 + + sample['mask'] = torch.from_numpy(mask).float() + + return sample + + +class RandomAffine(object): + """ + Random affine translation + """ + def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError("If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ + "degrees should be a list or tuple and it must be of length 2." + self.degrees = degrees + + if translate is not None: + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "translate should be a list or tuple and it must be of length 2." + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + if isinstance(shear, numbers.Number): + if shear < 0: + raise ValueError("If shear is a single number, it must be positive.") + self.shear = (-shear, shear) + else: + assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ + "shear should be a list or tuple and it must be of length 2." + self.shear = shear + else: + self.shear = shear + + self.resample = resample + self.fillcolor = fillcolor + self.flip = flip + + @staticmethod + def get_params(degrees, translate, scale_ranges, shears, flip, img_size): + """Get parameters for affine transformation + + Returns: + sequence: params to be passed to the affine transformation + """ + angle = random.uniform(degrees[0], degrees[1]) + if translate is not None: + max_dx = translate[0] * img_size[0] + max_dy = translate[1] * img_size[1] + translations = (np.round(random.uniform(-max_dx, max_dx)), + np.round(random.uniform(-max_dy, max_dy))) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = (random.uniform(scale_ranges[0], scale_ranges[1]), + random.uniform(scale_ranges[0], scale_ranges[1])) + else: + scale = (1.0, 1.0) + + if shears is not None: + shear = random.uniform(shears[0], shears[1]) + else: + shear = 0.0 + + if flip is not None: + flip = (np.random.rand(2) < flip).astype(np.int32) * 2 - 1 + + return angle, translations, scale, shear, flip + + def __call__(self, sample): + fg, alpha = sample['fg'], sample['alpha'] + rows, cols, ch = fg.shape + if np.maximum(rows, cols) < 1024: + params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, fg.size) + else: + params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, fg.size) + + center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5) + M = self._get_inverse_affine_matrix(center, *params) + M = np.array(M).reshape((2, 3)) + + fg = cv2.warpAffine(fg, M, (cols, rows), + flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP) + alpha = cv2.warpAffine(alpha, M, (cols, rows), + flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP) + + sample['fg'], sample['alpha'] = fg, alpha + + return sample + + + @ staticmethod + def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip): + + angle = math.radians(angle) + shear = math.radians(shear) + scale_x = 1.0 / scale[0] * flip[0] + scale_y = 1.0 / scale[1] * flip[1] + + # Inverted rotation matrix with scale and shear + d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle) + matrix = [ + math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0, + -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0 + ] + matrix = [m / d for m in matrix] + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1]) + matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1]) + + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += center[0] + matrix[5] += center[1] + + return matrix + + +class RandomJitter(object): + """ + Random change the hue of the image + """ + + def __call__(self, sample): + sample_ori = sample.copy() + fg, alpha = sample['fg'], sample['alpha'] + # if alpha is all 0 skip + if np.all(alpha==0): + return sample_ori + # convert to HSV space, convert to float32 image to keep precision during space conversion. + fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV) + # Hue noise + hue_jitter = np.random.randint(-40, 40) + fg[:, :, 0] = np.remainder(fg[:, :, 0].astype(np.float32) + hue_jitter, 360) + # Saturation noise + sat_bar = fg[:, :, 1][alpha > 0].mean() + if np.isnan(sat_bar): + return sample_ori + sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10 + sat = fg[:, :, 1] + sat = np.abs(sat + sat_jitter) + sat[sat>1] = 2 - sat[sat>1] + fg[:, :, 1] = sat + # Value noise + val_bar = fg[:, :, 2][alpha > 0].mean() + if np.isnan(val_bar): + return sample_ori + val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10 + val = fg[:, :, 2] + val = np.abs(val + val_jitter) + val[val>1] = 2 - val[val>1] + fg[:, :, 2] = val + # convert back to BGR space + fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR) + sample['fg'] = fg*255 + + return sample + + +class RandomHorizontalFlip(object): + """ + Random flip image and label horizontally + """ + def __init__(self, prob=0.5): + self.prob = prob + def __call__(self, sample): + fg, alpha = sample['fg'], sample['alpha'] + if np.random.uniform(0, 1) < self.prob: + fg = cv2.flip(fg, 1) + alpha = cv2.flip(alpha, 1) + sample['fg'], sample['alpha'] = fg, alpha + + return sample + + +class RandomCrop(object): + """ + Crop randomly the image in a sample, retain the center 1/4 images, and resize to 'output_size' + + :param output_size (tuple or int): Desired output size. If int, square crop + is made. + """ + + def __init__(self, output_size=( CONFIG.data.crop_size, CONFIG.data.crop_size)): + assert isinstance(output_size, (int, tuple)) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + else: + assert len(output_size) == 2 + self.output_size = output_size + self.margin = output_size[0] // 2 + self.logger = logging.getLogger("Logger") + + def __call__(self, sample): + fg, alpha, trimap, mask, name = sample['fg'], sample['alpha'], sample['trimap'], sample['mask'], sample['image_name'] + bg = sample['bg'] + h, w = trimap.shape + bg = cv2.resize(bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC)) + if w < self.output_size[0]+1 or h < self.output_size[1]+1: + ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w + # self.logger.warning("Size of {} is {}.".format(name, (h, w))) + while h < self.output_size[0]+1 or w < self.output_size[1]+1: + fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)), + interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST) + bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_CUBIC)) + mask = cv2.resize(mask, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST) + h, w = trimap.shape + small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST) + unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4, + self.margin//4:(w-self.margin)//4] == 128))) + unknown_num = len(unknown_list) + if len(unknown_list) < 10: + left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1)) + else: + idx = np.random.randint(unknown_num) + left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4) + + fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:] + alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]] + bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:] + trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]] + mask_crop = mask[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]] + + if len(np.where(trimap==128)[0]) == 0: + self.logger.error("{} does not have enough unknown area for crop. Resized to target size." + "left_top: {}".format(name, left_top)) + fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST) + bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC)) + mask_crop = cv2.resize(mask, self.output_size[::-1], interpolation=cv2.INTER_NEAREST) + + sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'mask': mask_crop, 'bg': bg_crop}) + return sample + + +class OriginScale(object): + def __call__(self, sample): + h, w = sample["alpha_shape"] + + if h % 32 == 0 and w % 32 == 0: + return sample + + target_h = 32 * ((h - 1) // 32 + 1) + target_w = 32 * ((w - 1) // 32 + 1) + pad_h = target_h - h + pad_w = target_w - w + + padded_image = np.pad(sample['image'], ((0,pad_h), (0, pad_w), (0,0)), mode="reflect") + padded_trimap = np.pad(sample['trimap'], ((0,pad_h), (0, pad_w)), mode="reflect") + padded_mask = np.pad(sample['mask'], ((0,pad_h), (0, pad_w)), mode="reflect") + + sample['image'] = padded_image + sample['trimap'] = padded_trimap + sample['mask'] = padded_mask + + return sample + + +class GenMask(object): + def __init__(self): + self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)] + + def __call__(self, sample): + alpha_ori = sample['alpha'] + h, w = alpha_ori.shape + + max_kernel_size = 30 + alpha = cv2.resize(alpha_ori, (640,640), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + + ### generate trimap + fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8) + bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8) + fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + + fg_width = np.random.randint(1, 30) + bg_width = np.random.randint(1, 30) + fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8) + bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8) + fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width]) + bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width]) + + trimap = np.ones_like(alpha) * 128 + trimap[fg_mask == 1] = 255 + trimap[bg_mask == 1] = 0 + + trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST) + sample['trimap'] = trimap + + ### generate mask + low = 0.01 + high = 1.0 + thres = random.random() * (high - low) + low + seg_mask = (alpha >= thres).astype(np.int32).astype(np.uint8) + random_num = random.randint(0,3) + if random_num == 0: + seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + elif random_num == 1: + seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + elif random_num == 2: + seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + elif random_num == 3: + seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + + seg_mask = cv2.resize(seg_mask, (w,h), interpolation=cv2.INTER_NEAREST) + sample['mask'] = seg_mask + + return sample + + +class Composite(object): + def __call__(self, sample): + fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha'] + alpha[alpha < 0 ] = 0 + alpha[alpha > 1] = 1 + fg[fg < 0 ] = 0 + fg[fg > 255] = 255 + bg[bg < 0 ] = 0 + bg[bg > 255] = 255 + + image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None]) + sample['image'] = image + return sample + + +class CutMask(object): + def __init__(self, perturb_prob = 0): + self.perturb_prob = perturb_prob + + def __call__(self, sample): + if np.random.rand() < self.perturb_prob: + return sample + + mask = sample['mask'] # H x W, trimap 0--255, segmask 0--1, alpha 0--1 + h, w = mask.shape + perturb_size_h, perturb_size_w = random.randint(h // 4, h // 2), random.randint(w // 4, w // 2) + x = random.randint(0, h - perturb_size_h) + y = random.randint(0, w - perturb_size_w) + x1 = random.randint(0, h - perturb_size_h) + y1 = random.randint(0, w - perturb_size_w) + + mask[x:x+perturb_size_h, y:y+perturb_size_w] = mask[x1:x1+perturb_size_h, y1:y1+perturb_size_w].copy() + + sample['mask'] = mask + return sample + + +class ScaleFg(object): + def __init__(self, min_scale_fg_scale=0.5, max_scale_fg_scale=1.0): + self.min_scale_fg_scale = min_scale_fg_scale + self.max_scale_fg_scale = max_scale_fg_scale + + def __call__(self, sample): + scale_factor = np.random.uniform(low=self.min_scale_fg_scale, high=self.max_scale_fg_scale) + + fg, alpha = sample['fg'], sample['alpha'] # np.array(): [H, W, 3] 0 ~ 255 , [H, W] 0.0 ~ 1.0 + h, w = alpha.shape + scale_h, scale_w = int(h * scale_factor), int(w * scale_factor) + + new_fg, new_alpha = np.zeros_like(fg), np.zeros_like(alpha) + fg = cv2.resize(fg, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR) + alpha = cv2.resize(alpha, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR) + + if scale_factor <= 1: + offset_h, offset_w = np.random.randint(h - scale_h + 1), np.random.randint(w - scale_w + 1) + new_fg[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w, :] = fg + new_alpha[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w] = alpha + else: + offset_h, offset_w = np.random.randint(scale_h - h + 1), np.random.randint(scale_w - w + 1) + new_fg = fg[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w, :] + new_alpha = alpha[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w] + + sample['fg'], sample['alpha'] = new_fg, new_alpha + return sample + +class GenBBox(object): + def __init__(self, bbox_offset_factor = 0.1, random_crop_bbox = None, train_or_test = 'train', dataset_type = None, random_auto_matting=None): + self.bbox_offset_factor = bbox_offset_factor + self.random_crop_bbox = random_crop_bbox + self.train_or_test = train_or_test + self.dataset_type = dataset_type + self.random_auto_matting = random_auto_matting + + def __call__(self, sample): + + alpha = sample['alpha'] # [1, H, W] 0.0 ~ 1.0 + indices = torch.nonzero(alpha[0], as_tuple=True) + + if len(indices[0]) > 0: + + min_x, min_y = torch.min(indices[1]), torch.min(indices[0]) + max_x, max_y = torch.max(indices[1]), torch.max(indices[0]) + + if self.random_crop_bbox is not None and np.random.uniform(0, 1) < self.random_crop_bbox: + ori_h_w = (sample['alpha'].shape[-2], sample['alpha'].shape[-1]) + sample['alpha'] = F.interpolate(sample['alpha'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0] + sample['image'] = F.interpolate(sample['image'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0] + sample['trimap'] = F.interpolate(sample['trimap'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='nearest')[0] + bbox = torch.tensor([[0, 0, ori_h_w[1] - 1, ori_h_w[0] - 1]]) + + elif self.bbox_offset_factor != 0: + bbox_w = max(1, max_x - min_x) + bbox_h = max(1, max_y - min_y) + offset_w = math.ceil(self.bbox_offset_factor * bbox_w) + offset_h = math.ceil(self.bbox_offset_factor * bbox_h) + + min_x = max(0, min_x + np.random.randint(-offset_w, offset_w)) + max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-offset_w, offset_w)) + min_y = max(0, min_y + np.random.randint(-offset_h, offset_h)) + max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-offset_h, offset_h)) + bbox = torch.tensor([[min_x, min_y, max_x, max_y]]) + else: + bbox = torch.tensor([[min_x, min_y, max_x, max_y]]) + + if self.random_auto_matting is not None and np.random.uniform(0, 1) < self.random_auto_matting: + bbox = torch.tensor([[0, 0, alpha.shape[2] - 1, alpha.shape[1] - 1]]) + + else: + bbox = torch.zeros(1, 4) + + sample['bbox'] = bbox.float() + return sample + +class DataGenerator(Dataset): + def __init__( + self, + data, + phase="train", + crop_size=512, + remove_multi_fg=False, + min_scale_fg_scale=None, + max_scale_fg_scale=None, + with_bbox = False, + bbox_offset_factor = None, + return_keys = None, + random_crop_bbox = None, + dataset_name = None, + random_auto_matting = None, + ): + self.phase = phase + # self.crop_size = CONFIG.data.crop_size + self.crop_size = crop_size + self.remove_multi_fg = remove_multi_fg + self.with_bbox = with_bbox + self.bbox_offset_factor = bbox_offset_factor + self.alpha = data.alpha + self.return_keys = return_keys + self.random_crop_bbox = random_crop_bbox + self.dataset_name = dataset_name + self.random_auto_matting = random_auto_matting + + if self.phase == "train": + self.fg = data.fg + self.bg = data.bg + self.merged = [] + self.trimap = [] + else: + self.fg = [] + self.bg = [] + self.merged = data.merged + self.trimap = data.trimap + + train_trans = [ + RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5), + GenMask(), + CutMask(perturb_prob=CONFIG.data.cutmask_prob), + RandomCrop((self.crop_size, self.crop_size)), + RandomJitter(), + Composite(), + ToTensor(phase="train") + ] + if min_scale_fg_scale is not None: + train_trans.insert(0, ScaleFg(min_scale_fg_scale, max_scale_fg_scale)) + if self.with_bbox: + train_trans.append(GenBBox(bbox_offset_factor=self.bbox_offset_factor, random_crop_bbox=self.random_crop_bbox, random_auto_matting=self.random_auto_matting)) + + test_trans = [ OriginScale(), ToTensor() ] + + self.transform = { + 'train': + transforms.Compose(train_trans), + 'val': + transforms.Compose([ + OriginScale(), + ToTensor() + ]), + 'test': + transforms.Compose(test_trans) + }[phase] + + self.fg_num = len(self.fg) + + def select_keys(self, sample): + new_sample = {} + for key, val in sample.items(): + if key in self.return_keys: + new_sample[key] = val + return new_sample + + def __getitem__(self, idx): + if self.phase == "train": + fg = cv2.imread(self.fg[idx % self.fg_num]) + alpha = cv2.imread(self.alpha[idx % self.fg_num], 0).astype(np.float32)/255 + bg = cv2.imread(self.bg[idx], 1) + + if not self.remove_multi_fg: + fg, alpha, multi_fg = self._composite_fg(fg, alpha, idx) + else: + multi_fg = False + image_name = os.path.split(self.fg[idx % self.fg_num])[-1] + sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name, 'multi_fg': multi_fg} + + else: + image = cv2.imread(self.merged[idx]) + alpha = cv2.imread(self.alpha[idx], 0)/255. + trimap = cv2.imread(self.trimap[idx], 0) + mask = (trimap >= 170).astype(np.float32) + image_name = os.path.split(self.merged[idx])[-1] + + sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape} + + sample = self.transform(sample) + + if self.return_keys is not None: + sample = self.select_keys(sample) + if self.dataset_name is not None: + sample['dataset_name'] = self.dataset_name + return sample + + def _composite_fg(self, fg, alpha, idx): + + multi_fg = False + if np.random.rand() < 0.5: + idx2 = np.random.randint(self.fg_num) + idx + fg2 = cv2.imread(self.fg[idx2 % self.fg_num]) + alpha2 = cv2.imread(self.alpha[idx2 % self.fg_num], 0).astype(np.float32)/255. + h, w = alpha.shape + fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + + alpha_tmp = 1 - (1 - alpha) * (1 - alpha2) + if np.any(alpha_tmp < 1): + fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None]) + # The overlap of two 50% transparency should be 25% + alpha = alpha_tmp + fg = fg.astype(np.uint8) + multi_fg = True + + if np.random.rand() < 0.25: + # fg = cv2.resize(fg, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + # alpha = cv2.resize(alpha, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) + + return fg, alpha, multi_fg + + def __len__(self): + if self.phase == "train": + return len(self.bg) + else: + return len(self.alpha) + + +class ResziePad(object): + + def __init__(self, target_size=1024): + self.target_size = target_size + + def __call__(self, sample): + _, H, W = sample['image'].shape + + scale = self.target_size * 1.0 / max(H, W) + new_H, new_W = H * scale, W * scale + new_W = int(new_W + 0.5) + new_H = int(new_H + 0.5) + + choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() else {'image', 'alpha'} + for key in choice: + if key in {'image', 'trimap'}: + sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0] + else: + # sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='nearest')[0] + sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0] + padding = torch.zeros([sample[key].shape[0], self.target_size, self.target_size], dtype=sample[key].dtype, device=sample[key].device) + padding[:, : new_H, : new_W] = sample[key] + sample[key] = padding + + return sample + + +class Cv2ResziePad(object): + + def __init__(self, target_size=1024): + self.target_size = target_size + + def __call__(self, sample): + H, W, _ = sample['image'].shape + + scale = self.target_size * 1.0 / max(H, W) + new_H, new_W = H * scale, W * scale + new_W = int(new_W + 0.5) + new_H = int(new_H + 0.5) + + choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() and sample['trimap'] is not None else {'image', 'alpha'} + for key in choice: + sample[key] = cv2.resize(sample[key], (new_W, new_H), interpolation=cv2.INTER_LINEAR) # cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC + + if key == 'image': + padding = np.zeros([self.target_size, self.target_size, sample[key].shape[-1]], dtype=sample[key].dtype) + padding[: new_H, : new_W, :] = sample[key] + sample[key] = padding + sample[key] = sample[key][:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) #/ 255.0 + else: + padding = np.zeros([self.target_size, self.target_size], dtype=sample[key].dtype) + padding[: new_H, : new_W] = sample[key] + sample[key] = padding + sample[key] = sample[key][None].astype(np.float32) + sample[key] = torch.from_numpy(sample[key]) + + return sample + + +class AdobeCompositionTest(Dataset): + def __init__(self, data_dir, target_size=1024, multi_fg=None): + self.data_dir = data_dir + self.file_names = sorted(os.listdir(os.path.join(self.data_dir, 'merged'))) + + test_trans = [ + ResziePad(target_size=target_size), + GenBBox(bbox_offset_factor=0) + ] + self.transform = transforms.Compose(test_trans) + self.multi_fg = multi_fg + + def __len__(self): # 1000 + return len(self.file_names) + + def __getitem__(self, idx): + phas = Image.open(os.path.join(self.data_dir, 'alpha_copy', self.file_names[idx])).convert('L') + tris = Image.open(os.path.join(self.data_dir, 'trimaps', self.file_names[idx])) + imgs = Image.open(os.path.join(self.data_dir, 'merged', self.file_names[idx])) + sample = { + 'ori_h_w': (imgs.size[1], imgs.size[0]), + 'data_type': 'Adobe' + } + + sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0 + sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0 + sample['image'] = torchvision.transforms.functional.to_tensor(imgs) + sample['image_name'] = 'Adobe_' + self.file_names[idx] + + sample = self.transform(sample) + sample['trimap'][sample['trimap'] < 85] = 0 + sample['trimap'][sample['trimap'] >= 170] = 1 + sample['trimap'][sample['trimap'] >= 85] = 0.5 + + if self.multi_fg is not None: + sample['multi_fg'] = torch.tensor(self.multi_fg) + + return sample + + +class SIMTest(Dataset): + def __init__(self, data_dir, target_size=1024, multi_fg=None): + self.data_dir = data_dir + self.file_names = sorted(glob.glob(os.path.join(*[data_dir, '*', 'alpha', '*']))) # [: 10] + test_trans = [ + ResziePad(target_size=target_size), + GenBBox(bbox_offset_factor=0) + ] + self.transform = transforms.Compose(test_trans) + self.multi_fg = multi_fg + + def __len__(self): # 1000 + return len(self.file_names) + + def __getitem__(self, idx): + phas = Image.open(self.file_names[idx]).convert('L') + # tris = Image.open(self.file_names[idx].replace('alpha', 'trimap')) + imgs = Image.open(self.file_names[idx].replace('alpha', 'merged')) + sample = { + 'ori_h_w': (imgs.size[1], imgs.size[0]), + 'data_type': 'SIM' + } + + sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0 + # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0 + sample['image'] = torchvision.transforms.functional.to_tensor(imgs) + sample['image_name'] = 'SIM_{}_{}'.format(self.file_names[idx].split('/')[-3], self.file_names[idx].split('/')[-1]) + + sample = self.transform(sample) + # sample['trimap'][sample['trimap'] < 85] = 0 + # sample['trimap'][sample['trimap'] >= 170] = 1 + # sample['trimap'][sample['trimap'] >= 85] = 0.5 + + if self.multi_fg is not None: + sample['multi_fg'] = torch.tensor(self.multi_fg) + + return sample + + +class RW100Test(Dataset): + def __init__(self, data_dir, target_size=1024, multi_fg=None): + self.data_dir = data_dir + self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'mask', '*']))) + + self.name_to_idx = dict() + for idx, file_name in enumerate(self.file_names): + self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx + + test_trans = [ + ResziePad(target_size=target_size), + GenBBox(bbox_offset_factor=0, train_or_test='test', dataset_type='RW100') + ] + self.transform = transforms.Compose(test_trans) + self.multi_fg = multi_fg + + def __len__(self): # 1000 + return len(self.file_names) + + def __getitem__(self, idx): + phas = Image.open(self.file_names[idx]).convert('L') + imgs = Image.open(self.file_names[idx].replace('mask', 'image')[:-6] + '.jpg') + sample = { + 'ori_h_w': (imgs.size[1], imgs.size[0]), + 'data_type': 'RW100' + } + + sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0 + sample['image'] = torchvision.transforms.functional.to_tensor(imgs) + sample['image_name'] = 'RW100_' + self.file_names[idx].split('/')[-1] + + sample = self.transform(sample) + + if self.multi_fg is not None: + sample['multi_fg'] = torch.tensor(self.multi_fg) + + return sample + + +class AIM500Test(Dataset): + def __init__(self, data_dir, target_size=1024, multi_fg=None): + self.data_dir = data_dir + self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original', '*']))) + + self.name_to_idx = dict() + for idx, file_name in enumerate(self.file_names): + self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx + + test_trans = [ + ResziePad(target_size=target_size), + GenBBox(bbox_offset_factor=0) + ] + self.transform = transforms.Compose(test_trans) + self.multi_fg = multi_fg + + def __len__(self): # 1000 + return len(self.file_names) + + def __getitem__(self, idx): + phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L') + # tris = Image.open(self.file_names[idx].replace('original', 'trimap').replace('jpg', 'png')).convert('L') + imgs = Image.open(self.file_names[idx]) + sample = { + 'ori_h_w': (imgs.size[1], imgs.size[0]), + 'data_type': 'AIM500' + } + + sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0 + # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0 + sample['image'] = torchvision.transforms.functional.to_tensor(imgs) + sample['image_name'] = 'AIM500_' + self.file_names[idx].split('/')[-1] + + sample = self.transform(sample) + # sample['trimap'][sample['trimap'] < 85] = 0 + # sample['trimap'][sample['trimap'] >= 170] = 1 + # sample['trimap'][sample['trimap'] >= 85] = 0.5 + + if self.multi_fg is not None: + sample['multi_fg'] = torch.tensor(self.multi_fg) + + return sample + + +class RWP636Test(Dataset): + def __init__(self, data_dir, target_size=1024, multi_fg=None): + self.data_dir = data_dir + self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'image', '*']))) + + self.name_to_idx = dict() + for idx, file_name in enumerate(self.file_names): + self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx + + test_trans = [ + ResziePad(target_size=target_size), + GenBBox(bbox_offset_factor=0) + ] + self.transform = transforms.Compose(test_trans) + self.multi_fg = multi_fg + + def __len__(self): # 1000 + return len(self.file_names) + + def __getitem__(self, idx): + phas = Image.open(self.file_names[idx].replace('image', 'alpha').replace('jpg', 'png')).convert('L') + imgs = Image.open(self.file_names[idx]) + sample = { + 'ori_h_w': (imgs.size[1], imgs.size[0]), + 'data_type': 'RWP636' + } + + sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0 + sample['image'] = torchvision.transforms.functional.to_tensor(imgs) + sample['image_name'] = 'RWP636_' + self.file_names[idx].split('/')[-1] + + sample = self.transform(sample) + + if self.multi_fg is not None: + sample['multi_fg'] = torch.tensor(self.multi_fg) + + return sample + + +class AM2KTest(Dataset): + def __init__(self, data_dir, target_size=1024, multi_fg=None): + self.data_dir = data_dir + self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'validation/original', '*']))) + test_trans = [ + ResziePad(target_size=target_size), + GenBBox(bbox_offset_factor=0) + ] + self.transform = transforms.Compose(test_trans) + self.multi_fg = multi_fg + + def __len__(self): # 1000 + return len(self.file_names) + + def __getitem__(self, idx): + phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L') + # tris = Image.open(self.file_names[idx].replace('original', 'trimap').replace('jpg', 'png')).convert('L') + imgs = Image.open(self.file_names[idx]) + sample = { + 'ori_h_w': (imgs.size[1], imgs.size[0]), + 'data_type': 'AM2K' + } + + sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0 + # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0 + sample['image'] = torchvision.transforms.functional.to_tensor(imgs) + sample['image_name'] = 'AM2K_' + self.file_names[idx].split('/')[-1] + + sample = self.transform(sample) + # sample['trimap'][sample['trimap'] < 85] = 0 + # sample['trimap'][sample['trimap'] >= 170] = 1 + # sample['trimap'][sample['trimap'] >= 85] = 0.5 + + if self.multi_fg is not None: + sample['multi_fg'] = torch.tensor(self.multi_fg) + + return sample + + +class P3M500Test(Dataset): + def __init__(self, data_dir, target_size=1024, multi_fg=None): + self.data_dir = data_dir + self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original_image', '*']))) + + self.name_to_idx = dict() + for idx, file_name in enumerate(self.file_names): + self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx + + test_trans = [ + ResziePad(target_size=target_size), + GenBBox(bbox_offset_factor=0) + ] + self.transform = transforms.Compose(test_trans) + self.multi_fg = multi_fg + + def __len__(self): # 1000 + return len(self.file_names) + + def __getitem__(self, idx): + phas = Image.open(self.file_names[idx].replace('original_image', 'mask').replace('jpg', 'png')).convert('L') + # tris = Image.open(self.file_names[idx].replace('original_image', 'trimap').replace('jpg', 'png')).convert('L') + imgs = Image.open(self.file_names[idx]) + sample = { + 'ori_h_w': (imgs.size[1], imgs.size[0]), + 'data_type': 'P3M500' + } + + sample['alpha'] = torchvision.transforms.functional.to_tensor(phas) # [1, H, W] 0.0 ~ 1.0 + # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0 + sample['image'] = torchvision.transforms.functional.to_tensor(imgs) + sample['image_name'] = 'P3M500_' + self.file_names[idx].split('/')[-1] + + sample = self.transform(sample) + # sample['trimap'][sample['trimap'] < 85] = 0 + # sample['trimap'][sample['trimap'] >= 170] = 1 + # sample['trimap'][sample['trimap'] >= 85] = 0.5 + + if self.multi_fg is not None: + sample['multi_fg'] = torch.tensor(self.multi_fg) + + return sample + + +class MattingTest(Dataset): + def __init__( + self, + data_type, + data_dir, + image_sub_path, + alpha_sub_path, + trimpa_sub_path=None, + target_size=1024, + multi_fg=None, + ): + self.data_type = data_type + self.data_dir = data_dir + + self.image_paths = sorted(glob.glob(os.path.join(*[data_dir, image_sub_path]))) + self.alpha_paths = sorted(glob.glob(os.path.join(*[data_dir, alpha_sub_path]))) + self.trimpa_paths = sorted(glob.glob(os.path.join(*[data_dir, trimpa_sub_path]))) if trimpa_sub_path is not None else None + + self.name_to_idx = dict() + for idx, file_name in enumerate(self.image_paths): + self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx + + test_trans = [ + Cv2ResziePad(target_size=target_size), + GenBBox(bbox_offset_factor=0) + ] + self.transform = transforms.Compose(test_trans) + self.multi_fg = multi_fg + + def __len__(self): # 1000 + return len(self.image_paths) + + def __getitem__(self, idx): + + img = cv2.imread(self.image_paths[idx]) + sample = { + 'image': img.astype(np.float32) / 255, + 'alpha': cv2.imread(self.alpha_paths[idx], 0).astype(np.float32) / 255, + 'trimap': cv2.imread(self.trimpa_paths[idx], 0) if self.trimpa_paths is not None else None, + 'ori_h_w': (img.shape[0], img.shape[1]), + 'data_type': self.data_type, + 'image_name': self.data_type + '_' + self.image_paths[idx].split('/')[-1] + } + + sample = self.transform(sample) + if self.trimpa_paths is not None: + sample['trimap'][sample['trimap'] < 85] = 0 + sample['trimap'][sample['trimap'] >= 170] = 1 + sample['trimap'][sample['trimap'] >= 85] = 0.5 + else: + del sample['trimap'] + + if self.multi_fg is not None: + sample['multi_fg'] = torch.tensor(self.multi_fg) + + return sample + + +def adobe_composition_collate_fn(batch): + new_batch = defaultdict(list) + for sub_batch in batch: + for key in sub_batch.keys(): + new_batch[key].append(sub_batch[key]) + for key in new_batch: + if isinstance(new_batch[key][0], torch.Tensor): + new_batch[key] = torch.stack(new_batch[key]) + return dict(new_batch) + + +def build_d2_test_dataloader( + dataset, + mapper=None, + total_batch_size=None, + local_batch_size=None, + num_workers=0, + collate_fn=None +): + + assert (total_batch_size is None) != ( + local_batch_size is None + ), "Either total_batch_size or local_batch_size must be specified" + + world_size = comm.get_world_size() + + if total_batch_size is not None: + assert ( + total_batch_size > 0 and total_batch_size % world_size == 0 + ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format( + total_batch_size, world_size + ) + batch_size = total_batch_size // world_size + + if local_batch_size is not None: + batch_size = local_batch_size + + logger = logging.getLogger(__name__) + if batch_size != 1: + logger.warning( + "When testing, batch size is set to 1. " + "This is the only mode that is supported for d2." + ) + + return build_detection_test_loader( + dataset=dataset, + mapper=mapper, + sampler=None, + num_workers=num_workers, + collate_fn=collate_fn, + ) + + +class AdobeCompositionEvaluator(DatasetEvaluator): + + def __init__( + self, + save_eval_results_step=-1, + output_dir=None, + eval_dataset_type=['Adobe'], + distributed=True, + eval_w_sam_hq_mask = False, + ): + + self.save_eval_results_step = save_eval_results_step + self.output_dir = output_dir + self.eval_index = 0 + self.eval_dataset_type = eval_dataset_type + self.eval_w_sam_hq_mask = eval_w_sam_hq_mask + + self._distributed = distributed + self._logger = logging.getLogger(__name__) + + def reset(self): + self.eval_metric = dict() + for i in self.eval_dataset_type: + self.eval_metric[i + '_MSE'] = [] + self.eval_metric[i + '_SAD'] = [] + self.eval_metric[i + '_MAD'] = [] + self.eval_metric[i + '_Grad'] = [] + self.eval_metric[i + '_Conn'] = [] + + os.makedirs(self.output_dir, exist_ok=True) if self.output_dir is not None else None + + def process(self, inputs, outputs): + """ + Args: + inputs: {'alpha', 'trimap', 'image', 'bbox', 'image_name'} + outputs: [1, 1, H, W] 0. ~ 1. + """ + + # crop the black pad area + assert inputs['image'].shape[-1] == inputs['image'].shape[-2] == 1024 and len(inputs['ori_h_w']) == 1 + inputs['ori_h_w'] = inputs['ori_h_w'][0] + before_pad_h, before_pad_w = int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][0] + 0.5), int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][1] + 0.5) + inputs['image'] = inputs['image'][:, :, :before_pad_h, :before_pad_w] + inputs['alpha'] = inputs['alpha'][:, :, :before_pad_h, :before_pad_w] + + if self.eval_w_sam_hq_mask: + outputs, samhq_low_res_masks = outputs[0][:, :, :before_pad_h, :before_pad_w], outputs[1][:, :, :before_pad_h, :before_pad_w] + pred_alpha, label_alpha, samhq_low_res_masks = outputs.cpu().numpy(), inputs['alpha'].numpy(), (samhq_low_res_masks > 0).float().cpu() + else: + outputs = outputs[:, :, :before_pad_h, :before_pad_w] + pred_alpha, label_alpha = outputs.cpu().numpy(), inputs['alpha'].numpy() + + # if 'trimap' in inputs.keys(): + # inputs['trimap'] = inputs['trimap'][:, :, :before_pad_h, :before_pad_w] + # trimap = inputs['trimap'].numpy() + # assert np.max(trimap) <= 1 and len(np.unique(trimap)) <= 3 + # sad_loss_unknown = compute_sad_loss(pred_alpha, label_alpha, trimap, area='unknown') + # mse_loss_unknown = compute_mse_loss(pred_alpha, label_alpha, trimap, area='unknown') + + # self.eval_metric[inputs['data_type'][0] + '_unknown_mse (1e-3)'].append(mse_loss_unknown) + # self.eval_metric[inputs['data_type'][0] + '_unknown_sad (1e3)'].append(sad_loss_unknown) + + # calculate loss + assert np.max(pred_alpha) <= 1 and np.max(label_alpha) <= 1 + eval_pred = np.uint8(pred_alpha[0, 0] * 255.0 + 0.5) * 1.0 + eval_gt = label_alpha[0, 0] * 255.0 + + detailmap = np.zeros_like(eval_gt) + 128 + mse_loss_ = compute_mse_loss(eval_pred, eval_gt, detailmap) + sad_loss_ = compute_sad_loss(eval_pred, eval_gt, detailmap)[0] + mad_loss_ = compute_mad_loss(eval_pred, eval_gt, detailmap) + grad_loss_ = compute_gradient_loss(eval_pred, eval_gt, detailmap) + conn_loss_ = compute_connectivity_error(eval_pred, eval_gt, detailmap) + + self.eval_metric[inputs['data_type'][0] + '_MSE'].append(mse_loss_) + self.eval_metric[inputs['data_type'][0] + '_SAD'].append(sad_loss_) + self.eval_metric[inputs['data_type'][0] + '_MAD'].append(mad_loss_) + self.eval_metric[inputs['data_type'][0] + '_Grad'].append(grad_loss_) + self.eval_metric[inputs['data_type'][0] + '_Conn'].append(conn_loss_) + + # vis results + if self.save_eval_results_step != -1 and self.eval_index % self.save_eval_results_step == 0: + if self.eval_w_sam_hq_mask: + self.save_vis_results(inputs, pred_alpha, samhq_low_res_masks) + else: + self.save_vis_results(inputs, pred_alpha) + self.eval_index += 1 + + def save_vis_results(self, inputs, pred_alpha, samhq_low_res_masks=None): + + # image + image = inputs['image'][0].permute(1, 2, 0) * 255.0 + l, u, r, d = int(inputs['bbox'][0, 0, 0].item()), int(inputs['bbox'][0, 0, 1].item()), int(inputs['bbox'][0, 0, 2].item()), int(inputs['bbox'][0, 0, 3].item()) + red_line = torch.tensor([[255., 0., 0.]], device=image.device, dtype=image.dtype) + image[u: d, l, :] = red_line + image[u: d, r, :] = red_line + image[u, l: r, :] = red_line + image[d, l: r, :] = red_line + image = np.uint8(image.numpy()) + + # trimap, pred_alpha, label_alpha + save_results = [image] + + choice = [inputs['trimap'], torch.from_numpy(pred_alpha), inputs['alpha']] if 'trimap' in inputs.keys() else [torch.from_numpy(pred_alpha), inputs['alpha']] + for val in choice: + val = val[0].permute(1, 2, 0).repeat(1, 1, 3) * 255.0 + 0.5 # +0.5 and int() = round() + val = np.uint8(val.numpy()) + save_results.append(val) + + if samhq_low_res_masks is not None: + save_results.append(np.uint8(samhq_low_res_masks[0].permute(1, 2, 0).repeat(1, 1, 3).numpy() * 255.0)) + + save_results = np.concatenate(save_results, axis=1) + save_name = os.path.join(self.output_dir, inputs['image_name'][0]) + Image.fromarray(save_results).save(save_name.replace('.jpg', '.png')) + + def evaluate(self): + + if self._distributed: + comm.synchronize() + eval_metric = comm.gather(self.eval_metric, dst=0) + + if not comm.is_main_process(): + return {} + + merges_eval_metric = defaultdict(list) + for sub_eval_metric in eval_metric: + for key, val in sub_eval_metric.items(): + merges_eval_metric[key] += val + eval_metric = merges_eval_metric + + else: + eval_metric = self.eval_metric + + eval_results = {} + + for key, val in eval_metric.items(): + if len(val) != 0: + # if 'mse' in key: + # eval_results[key] = np.array(val).mean() * 1e3 + # else: + # assert 'sad' in key + # eval_results[key] = np.array(val).mean() / 1e3 + eval_results[key] = np.array(val).mean() + + return eval_results + + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/data/evaluate.py b/data/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..53147cd3b3d2e6a14b2f4f90ed768d3bb9e5c0fd --- /dev/null +++ b/data/evaluate.py @@ -0,0 +1,102 @@ +import scipy.ndimage +import numpy as np +from skimage.measure import label +import scipy.ndimage.morphology + + +def gauss(x, sigma): + y = np.exp(-x ** 2 / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi)) + return y + + +def dgauss(x, sigma): + y = -x * gauss(x, sigma) / (sigma ** 2) + return y + + +def gaussgradient(im, sigma): + epsilon = 1e-2 + halfsize = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))).astype(np.int32) + size = 2 * halfsize + 1 + hx = np.zeros((size, size)) + for i in range(0, size): + for j in range(0, size): + u = [i - halfsize, j - halfsize] + hx[i, j] = gauss(u[0], sigma) * dgauss(u[1], sigma) + + hx = hx / np.sqrt(np.sum(np.abs(hx) * np.abs(hx))) + hy = hx.transpose() + + gx = scipy.ndimage.convolve(im, hx, mode='nearest') + gy = scipy.ndimage.convolve(im, hy, mode='nearest') + + return gx, gy + + +def compute_gradient_loss(pred, target, trimap): + + pred = pred / 255.0 + target = target / 255.0 + + pred_x, pred_y = gaussgradient(pred, 1.4) + target_x, target_y = gaussgradient(target, 1.4) + + pred_amp = np.sqrt(pred_x ** 2 + pred_y ** 2) + target_amp = np.sqrt(target_x ** 2 + target_y ** 2) + + error_map = (pred_amp - target_amp) ** 2 + loss = np.sum(error_map[trimap == 128]) + + return loss / 1000. + + +def getLargestCC(segmentation): + labels = label(segmentation, connectivity=1) + largestCC = labels == np.argmax(np.bincount(labels.flat)) + return largestCC + + +def compute_connectivity_error(pred, target, trimap, step=0.1): + pred = pred / 255.0 + target = target / 255.0 + h, w = pred.shape + + thresh_steps = list(np.arange(0, 1 + step, step)) + l_map = np.ones_like(pred, dtype=np.float32) * -1 + for i in range(1, len(thresh_steps)): + pred_alpha_thresh = (pred >= thresh_steps[i]).astype(np.int32) + target_alpha_thresh = (target >= thresh_steps[i]).astype(np.int32) + + omega = getLargestCC(pred_alpha_thresh * target_alpha_thresh).astype(np.int32) + flag = ((l_map == -1) & (omega == 0)).astype(np.int32) + l_map[flag == 1] = thresh_steps[i - 1] + + l_map[l_map == -1] = 1 + + pred_d = pred - l_map + target_d = target - l_map + pred_phi = 1 - pred_d * (pred_d >= 0.15).astype(np.int32) + target_phi = 1 - target_d * (target_d >= 0.15).astype(np.int32) + loss = np.sum(np.abs(pred_phi - target_phi)[trimap == 128]) + + return loss / 1000. + + +def compute_mse_loss(pred, target, trimap): + error_map = (pred - target) / 255.0 + loss = np.sum((error_map ** 2) * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8) + + return loss + + +def compute_sad_loss(pred, target, trimap): + error_map = np.abs((pred - target) / 255.0) + loss = np.sum(error_map * (trimap == 128)) + + return loss / 1000, np.sum(trimap == 128) / 1000 + +def compute_mad_loss(pred, target, trimap): + error_map = np.abs((pred - target) / 255.0) + loss = np.sum(error_map * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8) + + return loss diff --git a/data/p3m10k_dataset.py b/data/p3m10k_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..70806ac92a87c1dd30092f5f5eb9e92e5a95a1e5 --- /dev/null +++ b/data/p3m10k_dataset.py @@ -0,0 +1,325 @@ +import os +import torch +import numpy as np +import cv2 +from torch.utils.data import Dataset +from torchvision import transforms +import math +import torch.nn.functional as F + + +class GenBBox(object): + def __init__(self, bbox_offset_factor = 0.1, random_crop_bbox = None, train_or_test = 'train', dataset_type = None, random_auto_matting=None): + self.bbox_offset_factor = bbox_offset_factor + self.random_crop_bbox = random_crop_bbox + self.train_or_test = train_or_test + self.dataset_type = dataset_type + self.random_auto_matting = random_auto_matting + + def __call__(self, sample): + + alpha = sample['alpha'] # [1, H, W] 0.0 ~ 1.0 + indices = torch.nonzero(alpha[0], as_tuple=True) + + if len(indices[0]) > 0: + + min_x, min_y = torch.min(indices[1]), torch.min(indices[0]) + max_x, max_y = torch.max(indices[1]), torch.max(indices[0]) + + if self.random_crop_bbox is not None and np.random.uniform(0, 1) < self.random_crop_bbox: + ori_h_w = (sample['alpha'].shape[-2], sample['alpha'].shape[-1]) + sample['alpha'] = F.interpolate(sample['alpha'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0] + sample['image'] = F.interpolate(sample['image'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0] + sample['trimap'] = F.interpolate(sample['trimap'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='nearest')[0] + bbox = torch.tensor([[0, 0, ori_h_w[1] - 1, ori_h_w[0] - 1]]) + + elif self.bbox_offset_factor != 0: + bbox_w = max(1, max_x - min_x) + bbox_h = max(1, max_y - min_y) + offset_w = math.ceil(self.bbox_offset_factor * bbox_w) + offset_h = math.ceil(self.bbox_offset_factor * bbox_h) + + min_x = max(0, min_x + np.random.randint(-offset_w, offset_w)) + max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-offset_w, offset_w)) + min_y = max(0, min_y + np.random.randint(-offset_h, offset_h)) + max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-offset_h, offset_h)) + bbox = torch.tensor([[min_x, min_y, max_x, max_y]]) + else: + bbox = torch.tensor([[min_x, min_y, max_x, max_y]]) + + if self.random_auto_matting is not None and np.random.uniform(0, 1) < self.random_auto_matting: + bbox = torch.tensor([[0, 0, alpha.shape[2] - 1, alpha.shape[1] - 1]]) + + else: + bbox = torch.zeros(1, 4) + + sample['bbox'] = bbox.float() + return sample + +def random_interp(): + return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) + + +class SplitConcatImage(object): + + def __init__(self, concat_num=4, wo_mask_to_mattes=False): + self.concat_num = concat_num + self.wo_mask_to_mattes = wo_mask_to_mattes + if self.wo_mask_to_mattes: + assert self.concat_num == 5 + + def __call__(self, concat_image): + if isinstance(concat_image, list): + concat_image, image_path = concat_image[0], concat_image[1] + else: + image_path = None + H, W, _ = concat_image.shape + + concat_num = self.concat_num + if image_path is not None: + if '06-14' in image_path: + concat_num = 4 + elif 'ori_mask' in image_path or 'SEMat' in image_path: + concat_num = 3 + else: + concat_num = 5 + + assert W % concat_num == 0 + W = W // concat_num + + image = concat_image[:H, :W] + if self.concat_num != 3: + trimap = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W] + if self.wo_mask_to_mattes: + alpha = concat_image[:H, 2 * W: 3 * W] + else: + alpha = concat_image[:H, (concat_num - 1) * W: concat_num * W] + else: + trimap = concat_image[:H, (concat_num - 1) * W: concat_num * W] + alpha = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W] + + return {'image': image, 'trimap': trimap, 'alpha': alpha} + + +class RandomHorizontalFlip(object): + + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, sample): + if np.random.uniform(0, 1) < self.prob: + for key in sample.keys(): + sample[key] = cv2.flip(sample[key], 1) + return sample + +class EmptyAug(object): + def __call__(self, sample): + return sample + +class RandomReszieCrop(object): + + def __init__(self, output_size=1024, aug_scale_min=0.5, aug_scale_max=1.5): + self.desired_size = output_size + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + + def __call__(self, sample): + H, W, _ = sample['image'].shape + sample['trimap'] = sample['trimap'][:, :, None].repeat(3, axis=-1) + sample['alpha'] = sample['alpha'][:, :, None].repeat(3, axis=-1) + + if self.aug_scale_min == 1.0 and self.aug_scale_max == 1.0: + crop_H, crop_W = H, W + crop_y1, crop_y2 = 0, crop_H + crop_x1, crop_x2 = 0, crop_W + scale_W, scaled_H = W, H + elif self.aug_scale_min == -1.0 and self.aug_scale_max == -1.0: + scale = min(self.desired_size / H, self.desired_size / W) + scaled_H, scale_W = round(H * scale), round(W * scale) + crop_H, crop_W = scaled_H, scale_W + crop_y1, crop_y2 = 0, crop_H + crop_x1, crop_x2 = 0, crop_W + else: + # random size + random_scale = np.random.uniform(0, 1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min # random_val: 0.5 ~ 1.5 + scaled_size = round(random_scale * self.desired_size) + + scale = min(scaled_size / H, scaled_size / W) + scaled_H, scale_W = round(H * scale), round(W * scale) + + # random crop + crop_H, crop_W = min(self.desired_size, scaled_H), min(self.desired_size, scale_W) # crop_size + margin_H, margin_W = max(scaled_H - crop_H, 0), max(scale_W - crop_W, 0) + offset_H, offset_W = np.random.randint(0, margin_H + 1), np.random.randint(0, margin_W + 1) + crop_y1, crop_y2 = offset_H, offset_H + crop_H + crop_x1, crop_x2 = offset_W, offset_W + crop_W + + for key in sample.keys(): + sample[key] = cv2.resize(sample[key], (scale_W, scaled_H), interpolation=random_interp())[crop_y1: crop_y2, crop_x1: crop_x2, :] # resize and crop + padding = np.zeros(shape=(self.desired_size, self.desired_size, 3), dtype=sample[key].dtype) # pad to desired_size + padding[: crop_H, : crop_W, :] = sample[key] + sample[key] = padding + + return sample + + +class RandomJitter(object): + """ + Random change the hue of the image + """ + + def __call__(self, sample): + + image = sample['image'] + + # convert to HSV space, convert to float32 image to keep precision during space conversion. + image = cv2.cvtColor(image.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV) + # Hue noise + hue_jitter = np.random.randint(-40, 40) + image[:, :, 0] = np.remainder(image[:, :, 0].astype(np.float32) + hue_jitter, 360) + # Saturation noise + sat_bar = image[:, :, 1].mean() + + sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10 + sat = image[:, :, 1] + sat = np.abs(sat + sat_jitter) + sat[sat>1] = 2 - sat[sat>1] + image[:, :, 1] = sat + # Value noise + val_bar = image[:, :, 2].mean() + + val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10 + val = image[:, :, 2] + val = np.abs(val + val_jitter) + val[val>1] = 2 - val[val>1] + image[:, :, 2] = val + # convert back to BGR space + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + sample['image'] = image * 255 + + return sample + + +class ToTensor(object): + + def __call__(self, sample): + image, alpha, trimap = sample['image'][:, :, ::-1], sample['alpha'], sample['trimap'] + + # image + image = image.transpose((2, 0, 1)) / 255. + sample['image'] = torch.from_numpy(image).float() + + # alpha + alpha = alpha.transpose((2, 0, 1))[0: 1] / 255. + alpha[alpha < 0 ] = 0 + alpha[alpha > 1] = 1 + sample['alpha'] = torch.from_numpy(alpha).float() + + # trimap + trimap = trimap.transpose((2, 0, 1))[0: 1] / 1. + sample['trimap'] = torch.from_numpy(trimap).float() + sample['trimap'][sample['trimap'] < 85] = 0 + sample['trimap'][sample['trimap'] >= 170] = 1 + sample['trimap'][sample['trimap'] >= 85] = 0.5 + + return sample + + +class GenTrimap(object): + def __init__(self): + self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,100)] + + def __call__(self, sample): + alpha = sample['alpha'] + h, w = alpha.shape + + max_kernel_size = max(30, int((min(h,w) / 2048) * 30)) + + ### generate trimap + fg_mask = (alpha / 255.0 + 1e-5).astype(np.int32).astype(np.uint8) + bg_mask = (1 - alpha / 255.0 + 1e-5).astype(np.int32).astype(np.uint8) + fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + + trimap = np.ones_like(alpha) * 128 + trimap[fg_mask == 1] = 255 + trimap[bg_mask == 1] = 0 + + trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST) + sample['trimap'] = trimap + + return sample + + +class P3MData(Dataset): + def __init__( + self, + data_root_path = '/root/data/my_path_b/public_data/data/matting/P3M-10k/train/blurred_image/', + output_size = 1024, + aug_scale_min = 0.8, + aug_scale_max = 1.5, + with_bbox = True, + bbox_offset_factor = 0.05, + num_ratio = 4.06, # 9421 * 4.06 = 38249.26 (38251) + ): + + self.data_root_path = data_root_path + self.output_size = output_size + self.aug_scale_min = aug_scale_min + self.aug_scale_max = aug_scale_max + self.with_bbox = with_bbox + self.bbox_offset_factor = bbox_offset_factor + self.num_ratio = num_ratio + + self.image_names = os.listdir(self.data_root_path) + self.image_names = [i for i in self.image_names if 'jpg' in i] + self.image_names.sort() + + train_trans = [ + RandomHorizontalFlip(prob=0 if hasattr(self, 'return_image_name') and self.return_image_name else 0.5), + GenTrimap(), + RandomReszieCrop(self.output_size, self.aug_scale_min, self.aug_scale_max), + RandomJitter(), + ToTensor(), + GenBBox(bbox_offset_factor=self.bbox_offset_factor) + ] + self.transform = transforms.Compose(train_trans) + + def __getitem__(self, idx): + + if self.num_ratio is not None: + if self.num_ratio < 1.0: + idx = np.random.randint(0, len(self.image_names)) + else: + idx = idx % len(self.image_names) + + image_path = os.path.join(self.data_root_path, self.image_names[idx]) + alpha_path = image_path.replace('jpg', 'png').replace('blurred_image', 'mask') + + sample = self.transform({ + 'image': cv2.imread(image_path), + 'alpha': cv2.imread(alpha_path, 0), + }) + + sample['dataset_name'] = 'P3M' + sample['multi_fg'] = False + + return sample + + def __len__(self): + if self.num_ratio is not None: + return int(len(self.image_names) * self.num_ratio) + else: + return len(self.image_names) + + +if __name__ == '__main__': + + dataset = P3MData() + data = dataset[0] + print(len(dataset)) + for key, val in data.items(): + if isinstance(val, torch.Tensor): + print(key, val.shape, torch.min(val), torch.max(val), torch.unique(val)) + else: + print(key, val) \ No newline at end of file diff --git a/data/rand_augment.py b/data/rand_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..89b10cf2c348e9cd39b89c272d38ff5479135f67 --- /dev/null +++ b/data/rand_augment.py @@ -0,0 +1,196 @@ +# copyright: https://github.com/ildoonet/pytorch-randaugment +# code in this file is adpated from rpmcruz/autoaugment +# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py +# This code is modified version of one of ildoonet, for randaugmentation of fixmatch. + +import random + +import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + + +def AutoContrast(img, _): + return PIL.ImageOps.autocontrast(img) + + +def Brightness(img, v): + assert v >= 0.0 + return PIL.ImageEnhance.Brightness(img).enhance(v) + + +def Color(img, v): + assert v >= 0.0 + return PIL.ImageEnhance.Color(img).enhance(v) + + +def Contrast(img, v): + assert v >= 0.0 + return PIL.ImageEnhance.Contrast(img).enhance(v) + + +def Equalize(img, _): + return PIL.ImageOps.equalize(img) + + +def Invert(img, _): + return PIL.ImageOps.invert(img) + + +def Identity(img, v): + return img + + +def Posterize(img, v): # [4, 8] + v = int(v) + v = max(1, v) + return PIL.ImageOps.posterize(img, v) + + +def Rotate(img, v): # [-30, 30] + #assert -30 <= v <= 30 + #if random.random() > 0.5: + # v = -v + return img.rotate(v) + + + +def Sharpness(img, v): # [0.1,1.9] + assert v >= 0.0 + return PIL.ImageEnhance.Sharpness(img).enhance(v) + + +def ShearX(img, v): # [-0.3, 0.3] + #assert -0.3 <= v <= 0.3 + #if random.random() > 0.5: + # v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) + + +def ShearY(img, v): # [-0.3, 0.3] + #assert -0.3 <= v <= 0.3 + #if random.random() > 0.5: + # v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) + + +def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + #assert -0.3 <= v <= 0.3 + #if random.random() > 0.5: + # v = -v + v = v * img.size[0] + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + #assert v >= 0.0 + #if random.random() > 0.5: + # v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + #assert -0.3 <= v <= 0.3 + #if random.random() > 0.5: + # v = -v + v = v * img.size[1] + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + #assert 0 <= v + #if random.random() > 0.5: + # v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +def Solarize(img, v): # [0, 256] + assert 0 <= v <= 256 + return PIL.ImageOps.solarize(img, v) + + +def Cutout(img, v): #[0, 60] => percentage: [0, 0.2] => change to [0, 0.5] + assert 0.0 <= v <= 0.5 + if v <= 0.: + return img + + v = v * img.size[0] + return CutoutAbs(img, v) + + +def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] + # assert 0 <= v <= 20 + if v < 0: + return img + w, h = img.size + x0 = np.random.uniform(w) + y0 = np.random.uniform(h) + + x0 = int(max(0, x0 - v / 2.)) + y0 = int(max(0, y0 - v / 2.)) + x1 = min(w, x0 + v) + y1 = min(h, y0 + v) + + xy = (x0, y0, x1, y1) + color = (125, 123, 114) + # color = (0, 0, 0) + img = img.copy() + PIL.ImageDraw.Draw(img).rectangle(xy, color) + return img + + +def augment_list(): + l = [ + (AutoContrast, 0, 1), + (Brightness, 0.05, 0.95), + (Color, 0.05, 0.95), + (Contrast, 0.05, 0.95), + (Equalize, 0, 1), + (Identity, 0, 1), + (Posterize, 4, 8), + # (Rotate, -30, 30), + (Sharpness, 0.05, 0.95), + # (ShearX, -0.3, 0.3), + # (ShearY, -0.3, 0.3), + (Solarize, 0, 256), + # (TranslateX, -0.3, 0.3), + # (TranslateY, -0.3, 0.3) + ] + return l + + +class RandAugment: + def __init__(self, n, m): + self.n = n + self.m = m # [0, 30] in fixmatch, deprecated. + self.augment_list = augment_list() + + + def __call__(self, img, cutout=True): + ops = random.choices(self.augment_list, k=self.n) + for op, min_val, max_val in ops: + val = min_val + float(max_val - min_val)*random.random() + img = op(img, val) + if cutout: + cutout_val = random.random() * 0.5 + img = Cutout(img, cutout_val) #for fixmatch + return img + + +if __name__ == '__main__': + # randaug = RandAugment(3,5) + # print(randaug) + # for item in randaug.augment_list: + # print(item) + import os + + os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' + img = PIL.Image.open('./u.jpg') + randaug = RandAugment(3,6) + img = randaug(img) + import matplotlib + from matplotlib import pyplot as plt + plt.imshow(img) + plt.show() \ No newline at end of file diff --git a/data/refmatte_dataset.py b/data/refmatte_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ea66a6ed590ff49611ed277839bd4d38620917 --- /dev/null +++ b/data/refmatte_dataset.py @@ -0,0 +1,418 @@ +import os +import torch +import numpy as np +import cv2 +from torch.utils.data import Dataset +from torchvision import transforms +import random +import imgaug.augmenters as iaa +import numbers +import math + + +def random_interp(): + return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) + +class RandomAffine(object): + """ + Random affine translation + """ + def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError("If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ + "degrees should be a list or tuple and it must be of length 2." + self.degrees = degrees + + if translate is not None: + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "translate should be a list or tuple and it must be of length 2." + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + if isinstance(shear, numbers.Number): + if shear < 0: + raise ValueError("If shear is a single number, it must be positive.") + self.shear = (-shear, shear) + else: + assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ + "shear should be a list or tuple and it must be of length 2." + self.shear = shear + else: + self.shear = shear + + self.resample = resample + self.fillcolor = fillcolor + self.flip = flip + + @staticmethod + def get_params(degrees, translate, scale_ranges, shears, flip, img_size): + """Get parameters for affine transformation + + Returns: + sequence: params to be passed to the affine transformation + """ + angle = random.uniform(degrees[0], degrees[1]) + if translate is not None: + max_dx = translate[0] * img_size[0] + max_dy = translate[1] * img_size[1] + translations = (np.round(random.uniform(-max_dx, max_dx)), + np.round(random.uniform(-max_dy, max_dy))) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = (random.uniform(scale_ranges[0], scale_ranges[1]), + random.uniform(scale_ranges[0], scale_ranges[1])) + else: + scale = (1.0, 1.0) + + if shears is not None: + shear = random.uniform(shears[0], shears[1]) + else: + shear = 0.0 + + if flip is not None: + flip = (np.random.rand(2) < flip).astype(np.int32) * 2 - 1 + + return angle, translations, scale, shear, flip + + def __call__(self, sample): + fg, alpha = sample['fg'], sample['alpha'] + rows, cols, ch = fg.shape + if np.maximum(rows, cols) < 1024: + params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, fg.size) + else: + params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, fg.size) + + center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5) + M = self._get_inverse_affine_matrix(center, *params) + M = np.array(M).reshape((2, 3)) + + fg = cv2.warpAffine(fg, M, (cols, rows), flags=random_interp() + cv2.WARP_INVERSE_MAP) + alpha = cv2.warpAffine(alpha, M, (cols, rows), flags=random_interp() + cv2.WARP_INVERSE_MAP) + + sample['fg'], sample['alpha'] = fg, alpha + + return sample + + @ staticmethod + def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip): + # Helper method to compute inverse matrix for affine transformation + + # As it is explained in PIL.Image.rotate + # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 + # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] + # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] + # RSS is rotation with scale and shear matrix + # It is different from the original function in torchvision + # The order are changed to flip -> scale -> rotation -> shear + # x and y have different scale factors + # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f -sin(a + shear)*scale_y 0] + # [ sin(a)*scale_x*f cos(a)*scale_y 0] + # [ 0 0 1] + # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 + + angle = math.radians(angle) + shear = math.radians(shear) + scale_x = 1.0 / scale[0] * flip[0] + scale_y = 1.0 / scale[1] * flip[1] + + # Inverted rotation matrix with scale and shear + d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle) + matrix = [ + math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0, + -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0 + ] + matrix = [m / d for m in matrix] + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1]) + matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1]) + + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += center[0] + matrix[5] += center[1] + + return matrix + + +class GenTrimap(object): + def __init__(self): + self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,100)] + + def __call__(self, sample): + alpha = sample['alpha'] + h, w = alpha.shape + + max_kernel_size = max(30, int((min(h,w) / 2048) * 30)) + + ### generate trimap + fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8) + bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8) + fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) + + trimap = np.ones_like(alpha) * 128 + trimap[fg_mask == 1] = 255 + trimap[bg_mask == 1] = 0 + + trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST) + sample['trimap'] = trimap + + return sample + + +class RandomCrop(object): + """ + Crop randomly the image in a sample, retain the center 1/4 images, and resize to 'output_size' + + :param output_size (tuple or int): Desired output size. If int, square crop + is made. + """ + + def __init__(self, output_size=(1024, 1024)): + assert isinstance(output_size, (int, tuple)) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + else: + assert len(output_size) == 2 + self.output_size = output_size + self.margin = output_size[0] // 2 + + def __call__(self, sample): + fg, alpha, trimap, name = sample['fg'], sample['alpha'], sample['trimap'], sample['image_name'] + bg = sample['bg'] + h, w = trimap.shape + bg = cv2.resize(bg, (w, h), interpolation=random_interp()) + if w < self.output_size[0]+1 or h < self.output_size[1]+1: + ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w + # self.logger.warning("Size of {} is {}.".format(name, (h, w))) + while h < self.output_size[0]+1 or w < self.output_size[1]+1: + fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=random_interp()) + alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)), + interpolation=random_interp()) + trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST) + bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=random_interp()) + h, w = trimap.shape + small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST) + unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4, + self.margin//4:(w-self.margin)//4] == 128))) + unknown_num = len(unknown_list) + if len(unknown_list) < 10: + left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1)) + else: + idx = np.random.randint(unknown_num) + left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4) + + fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:] + alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]] + bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:] + trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]] + + if len(np.where(trimap==128)[0]) == 0: + fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=random_interp()) + alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=random_interp()) + trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST) + bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=random_interp()) + + sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'bg': bg_crop}) + return sample + + +class Composite_Seg(object): + def __call__(self, sample): + fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha'] + fg[fg < 0 ] = 0 + fg[fg > 255] = 255 + image = fg + sample['image'] = image + return sample + + +class ToTensor(object): + """ + Convert ndarrays in sample to Tensors with normalization. + """ + def __init__(self, phase="test", real_world_aug = False): + # self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1) + # self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1) + self.mean = torch.tensor([0.0, 0.0, 0.0]).view(3,1,1) + self.std = torch.tensor([1.0, 1.0, 1.0]).view(3,1,1) + self.phase = phase + if real_world_aug: + self.RWA = iaa.SomeOf((1, None), [ + iaa.LinearContrast((0.6, 1.4)), + iaa.JpegCompression(compression=(0, 60)), + iaa.GaussianBlur(sigma=(0.0, 3.0)), + iaa.AdditiveGaussianNoise(scale=(0, 0.1*255)) + ], random_order=True) + else: + self.RWA = None + + def get_box_from_alpha(self, alpha_final): + bi_mask = np.zeros_like(alpha_final) + bi_mask[alpha_final>0.5] = 1 + #bi_mask[alpha_final<=0.5] = 0 + fg_set = np.where(bi_mask != 0) + if len(fg_set[1]) == 0 or len(fg_set[0]) == 0: + x_min = random.randint(1, 511) + x_max = random.randint(1, 511) + x_min + y_min = random.randint(1, 511) + y_max = random.randint(1, 511) + y_min + else: + x_min = np.min(fg_set[1]) + x_max = np.max(fg_set[1]) + y_min = np.min(fg_set[0]) + y_max = np.max(fg_set[0]) + bbox = np.array([x_min, y_min, x_max, y_max]) + #cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0,255,0), 2) + #cv2.imwrite('../outputs/test.jpg', image) + #cv2.imwrite('../outputs/test_gt.jpg', alpha_single) + return bbox + + def __call__(self, sample): + # convert GBR images to RGB + image, alpha, trimap = sample['image'][:,:,::-1], sample['alpha'], sample['trimap'] + + alpha[alpha < 0 ] = 0 + alpha[alpha > 1] = 1 + + bbox = self.get_box_from_alpha(alpha) + + if self.phase == 'train' and self.RWA is not None and np.random.rand() < 0.5: + image[image > 255] = 255 + image[image < 0] = 0 + image = np.round(image).astype(np.uint8) + image = np.expand_dims(image, axis=0) + image = self.RWA(images=image) + image = image[0, ...] + + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + image = image.transpose((2, 0, 1)).astype(np.float32) + alpha = np.expand_dims(alpha.astype(np.float32), axis=0) + trimap[trimap < 85] = 0 + trimap[trimap >= 170] = 2 + trimap[trimap >= 85] = 1 + #image = cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255,0,0), 3) + #cv2.imwrite(os.path.join('outputs', 'img_bbox.png'), image.astype('uint8')) + # normalize image + image /= 255. + + if self.phase == "train": + # convert GBR images to RGB + fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255. + sample['fg'] = torch.from_numpy(fg).sub_(self.mean).div_(self.std) + bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255. + sample['bg'] = torch.from_numpy(bg).sub_(self.mean).div_(self.std) + del sample['image_name'] + + sample['boxes'] = torch.from_numpy(bbox).to(torch.float)[None,...] + + sample['image'], sample['alpha'], sample['trimap'] = \ + torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long) + sample['image'] = sample['image'].sub_(self.mean).div_(self.std) + sample['trimap'] = sample['trimap'][None,...].float() + + return sample + + +class RefMatteData(Dataset): + def __init__( + self, + data_root_path, + num_ratio = 0.34, + ): + self.data_root_path = data_root_path + self.num_ratio = num_ratio + + self.rim_img = [os.path.join(data_root_path, name) for name in sorted(os.listdir(data_root_path))] + self.rim_pha = [os.path.join(data_root_path.replace('img', 'mask'), name) for name in sorted(os.listdir(data_root_path.replace('img', 'mask')))] + self.rim_num = len(self.rim_pha) + + self.transform_spd = transforms.Compose([ + RandomAffine(degrees=30, scale=[0.8, 1.5], shear=10, flip=0.5), + GenTrimap(), + RandomCrop((1024, 1024)), + Composite_Seg(), + ToTensor(phase="train", real_world_aug=False) + ]) + + def __getitem__(self, idx): + if self.num_ratio is not None: + if self.num_ratio < 1.0 or idx >= self.rim_num: + idx = np.random.randint(0, self.rim_num) + alpha = cv2.imread(self.rim_pha[idx % self.rim_num], 0).astype(np.float32)/255 + alpha_img_name = self.rim_pha[idx % self.rim_num].split('/')[-1] + fg_img_name = alpha_img_name[:-6] + '.jpg' + + fg = cv2.imread(os.path.join(self.data_root_path, fg_img_name)) + + if np.random.rand() < 0.25: + fg = cv2.resize(fg, (1280, 1280), interpolation=random_interp()) + alpha = cv2.resize(alpha, (1280, 1280), interpolation=random_interp()) + + image_name = alpha_img_name # os.path.split(self.rim_img[idx % self.rim_num])[-1] + sample = {'fg': fg, 'alpha': alpha, 'bg': fg, 'image_name': image_name} + sample = self.transform_spd(sample) + + converted_sample = { + 'image': sample['image'], + 'trimap': sample['trimap'] / 2.0, + 'alpha': sample['alpha'], + 'bbox': sample['boxes'], + 'dataset_name': 'RefMatte', + 'multi_fg': False, + } + return converted_sample + + def __len__(self): + if self.num_ratio is not None: + return int(self.rim_num * self.num_ratio) # 112506 * 0.34 = 38252 (COCONut_num-38251 + 1) + else: + return self.rim_num # 112506 + + + +if __name__ == '__main__': + dataset = RefMatteData( + data_root_path = '/data/my_path_b/public_data/data/matting/RefMatte/RefMatte/train/img', + num_ratio=0.34, + ) + data = dataset[0] + ''' + fg torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400) + alpha torch.Size([1, 1024, 1024]) tensor(0.) tensor(1.) + bg torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400) + trimap torch.Size([1, 1024, 1024]) 0.0 or 1.0 or 2.0 + image torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400) + boxes torch.Size([1, 4]) tensor(72.) tensor(676.) 0.0~1024.0 + + COCONut: + image torch.Size([3, 1024, 1024]) tensor(0.0006) tensor(0.9991) + trimap torch.Size([1, 1024, 1024]) 0.0 or 0.5 or 1.0 + alpha torch.Size([1, 1024, 1024]) tensor(0.) tensor(1.) + bbox torch.Size([1, 4]) tensor(0.) tensor(590.) + dataset_name: 'COCONut' + ''' + for key, val in data.items(): + if isinstance(val, torch.Tensor): + print(key, val.shape, torch.min(val), torch.max(val)) + else: + print(key, val.shape) \ No newline at end of file diff --git a/engine/__init__.py b/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a11d96bd453ed20783a091ec1110a09900cbac --- /dev/null +++ b/engine/__init__.py @@ -0,0 +1 @@ +from .mattingtrainer import MattingTrainer diff --git a/engine/hooks.py b/engine/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..eb9c486b488800c63e1c5a5f6453eccfc8fa156f --- /dev/null +++ b/engine/hooks.py @@ -0,0 +1,52 @@ +import inspect +import detectron2.utils.comm as comm +from detectron2.engine import EvalHook as _EvalHook +from detectron2.evaluation.testing import flatten_results_dict + + +class EvalHook(_EvalHook): + def __init__(self, eval_period, eval_function): + super().__init__(eval_period, eval_function) + func_args = inspect.getfullargspec(eval_function).args + assert {"final_iter", "next_iter"}.issubset(set(func_args)), ( + f"Eval function must have either 'final_iter' or 'next_iter' as an argument." + f"Got {func_args} instead." + ) + + def _do_eval(self, final_iter=False, next_iter=0): + results = self._func(final_iter=final_iter, next_iter=next_iter) + + if results: + assert isinstance( + results, dict + ), "Eval function must return a dict. Got {} instead.".format(results) + + flattened_results = flatten_results_dict(results) + for k, v in flattened_results.items(): + try: + v = float(v) + except Exception as e: + raise ValueError( + "[EvalHook] eval_function should return a nested dict of float. " + "Got '{}: {}' instead.".format(k, v) + ) from e + self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False) + + # Evaluation may take different time among workers. + # A barrier make them start the next iteration together. + comm.synchronize() + + def after_step(self): + next_iter = self.trainer.iter + 1 + if self._period > 0 and next_iter % self._period == 0: + # do the last eval in after_train + if next_iter != self.trainer.max_iter: + self._do_eval(next_iter=next_iter) + + def after_train(self): + # This condition is to prevent the eval from running after a failed training + if self.trainer.iter + 1 >= self.trainer.max_iter: + self._do_eval(final_iter=True) + # func is likely a closure that holds reference to the trainer + # therefore we clean it to avoid circular reference in the end + del self._func \ No newline at end of file diff --git a/engine/mattingtrainer.py b/engine/mattingtrainer.py new file mode 100644 index 0000000000000000000000000000000000000000..13849c706b9e9644625743803a6dd1c4d83f86dd --- /dev/null +++ b/engine/mattingtrainer.py @@ -0,0 +1,171 @@ +from detectron2.engine import AMPTrainer +import torch +import time +import logging + +logger = logging.getLogger("detectron2") + +import typing +from collections import defaultdict +import tabulate +from torch import nn + + +def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]: + """ + Count parameters of a model and its submodules. + + Args: + model: a torch module + + Returns: + dict (str-> int): the key is either a parameter name or a module name. + The value is the number of elements in the parameter, or in all + parameters of the module. The key "" corresponds to the total + number of parameters of the model. + """ + r = defaultdict(int) + for name, prm in model.named_parameters(): + if trainable_only: + if not prm.requires_grad: + continue + size = prm.numel() + name = name.split(".") + for k in range(0, len(name) + 1): + prefix = ".".join(name[:k]) + r[prefix] += size + return r + + +def parameter_count_table( + model: nn.Module, max_depth: int = 3, trainable_only: bool = False +) -> str: + """ + Format the parameter count of the model (and its submodules or parameters) + in a nice table. It looks like this: + + :: + + | name | #elements or shape | + |:--------------------------------|:---------------------| + | model | 37.9M | + | backbone | 31.5M | + | backbone.fpn_lateral3 | 0.1M | + | backbone.fpn_lateral3.weight | (256, 512, 1, 1) | + | backbone.fpn_lateral3.bias | (256,) | + | backbone.fpn_output3 | 0.6M | + | backbone.fpn_output3.weight | (256, 256, 3, 3) | + | backbone.fpn_output3.bias | (256,) | + | backbone.fpn_lateral4 | 0.3M | + | backbone.fpn_lateral4.weight | (256, 1024, 1, 1) | + | backbone.fpn_lateral4.bias | (256,) | + | backbone.fpn_output4 | 0.6M | + | backbone.fpn_output4.weight | (256, 256, 3, 3) | + | backbone.fpn_output4.bias | (256,) | + | backbone.fpn_lateral5 | 0.5M | + | backbone.fpn_lateral5.weight | (256, 2048, 1, 1) | + | backbone.fpn_lateral5.bias | (256,) | + | backbone.fpn_output5 | 0.6M | + | backbone.fpn_output5.weight | (256, 256, 3, 3) | + | backbone.fpn_output5.bias | (256,) | + | backbone.top_block | 5.3M | + | backbone.top_block.p6 | 4.7M | + | backbone.top_block.p7 | 0.6M | + | backbone.bottom_up | 23.5M | + | backbone.bottom_up.stem | 9.4K | + | backbone.bottom_up.res2 | 0.2M | + | backbone.bottom_up.res3 | 1.2M | + | backbone.bottom_up.res4 | 7.1M | + | backbone.bottom_up.res5 | 14.9M | + | ...... | ..... | + + Args: + model: a torch module + max_depth (int): maximum depth to recursively print submodules or + parameters + + Returns: + str: the table to be printed + """ + count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only) + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. + param_shape: typing.Dict[str, typing.Tuple] = { + k: tuple(v.shape) for k, v in model.named_parameters() + } + + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. + table: typing.List[typing.Tuple] = [] + + def format_size(x: int) -> str: + if x > 1e8: + return "{:.1f}G".format(x / 1e9) + if x > 1e5: + return "{:.1f}M".format(x / 1e6) + if x > 1e2: + return "{:.1f}K".format(x / 1e3) + return str(x) + + def fill(lvl: int, prefix: str) -> None: + if lvl >= max_depth: + return + for name, v in count.items(): + if name.count(".") == lvl and name.startswith(prefix): + indent = " " * (lvl + 1) + if name in param_shape: + table.append((indent + name, indent + str(param_shape[name]))) + else: + table.append((indent + name, indent + format_size(v))) + fill(lvl + 1, name + ".") + + table.append(("model", format_size(count.pop("")))) + fill(0, "") + + old_ws = tabulate.PRESERVE_WHITESPACE + tabulate.PRESERVE_WHITESPACE = True + tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe") + tabulate.PRESERVE_WHITESPACE = old_ws + return tab + + +def cycle(iterable): + while True: + for x in iterable: + yield x + +class MattingTrainer(AMPTrainer): + def __init__(self, model, data_loader, optimizer, grad_scaler=None): + super().__init__(model, data_loader, optimizer, grad_scaler=None) + self.data_loader_iter = iter(cycle(self.data_loader)) + + # print model parameters + logger.info("All parameters: \n" + parameter_count_table(model)) + logger.info("Trainable parameters: \n" + parameter_count_table(model, trainable_only=True, max_depth=8)) + + def run_step(self): + """ + Implement the AMP training logic. + """ + assert self.model.training, "[AMPTrainer] model was changed to eval mode!" + assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!" + from torch.cuda.amp import autocast + + #matting pass + start = time.perf_counter() + data = next(self.data_loader_iter) + data_time = time.perf_counter() - start + + with autocast(): + loss_dict = self.model(data) + if isinstance(loss_dict, torch.Tensor): + losses = loss_dict + loss_dict = {"total_loss": loss_dict} + else: + losses = sum(loss_dict.values()) + + self.optimizer.zero_grad() + self.grad_scaler.scale(losses).backward() + + self._write_metrics(loss_dict, data_time) + + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() \ No newline at end of file diff --git a/modeling/__init__.py b/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b89f6584f82a9ab1cf77d1b4032a8e5829dcda7 --- /dev/null +++ b/modeling/__init__.py @@ -0,0 +1,5 @@ +from .backbone import * +from .criterion import * +from .decoder import * +from .meta_arch import * +from .semantic_enhanced_matting import * \ No newline at end of file diff --git a/modeling/__pycache__/__init__.cpython-38.pyc b/modeling/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc4f753be9b7189152be5139a7df3a861069154a Binary files /dev/null and b/modeling/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/backbone/__init__.py b/modeling/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9420c05c3a5ea747b4a6e884ae079b82bae9fa39 --- /dev/null +++ b/modeling/backbone/__init__.py @@ -0,0 +1,2 @@ +from .backbone import * +from .vit import * \ No newline at end of file diff --git a/modeling/backbone/__pycache__/__init__.cpython-38.pyc b/modeling/backbone/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa50dd88a5d5a61448f0148ebdec757140d4dbeb Binary files /dev/null and b/modeling/backbone/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/backbone/__pycache__/backbone.cpython-38.pyc b/modeling/backbone/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..035d55d7a631c2437b53e88920227da03d3bdc04 Binary files /dev/null and b/modeling/backbone/__pycache__/backbone.cpython-38.pyc differ diff --git a/modeling/backbone/__pycache__/utils.cpython-38.pyc b/modeling/backbone/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa110dc5bbdcdc38091d97cccdabff98f3257465 Binary files /dev/null and b/modeling/backbone/__pycache__/utils.cpython-38.pyc differ diff --git a/modeling/backbone/__pycache__/vit.cpython-38.pyc b/modeling/backbone/__pycache__/vit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa38b7d5676eb701f478264a628e900434c9b590 Binary files /dev/null and b/modeling/backbone/__pycache__/vit.cpython-38.pyc differ diff --git a/modeling/backbone/backbone.py b/modeling/backbone/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c765a6b38542f66cae55216bba697a6626d128 --- /dev/null +++ b/modeling/backbone/backbone.py @@ -0,0 +1,74 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from abc import ABCMeta, abstractmethod +from typing import Dict +import torch.nn as nn + +from detectron2.layers import ShapeSpec + +__all__ = ["Backbone"] + + +class Backbone(nn.Module, metaclass=ABCMeta): + """ + Abstract base class for network backbones. + """ + + def __init__(self): + """ + The `__init__` method of any subclass can specify its own set of arguments. + """ + super().__init__() + + @abstractmethod + def forward(self): + """ + Subclasses must override this method, but adhere to the same return type. + + Returns: + dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor + """ + pass + + @property + def size_divisibility(self) -> int: + """ + Some backbones require the input height and width to be divisible by a + specific integer. This is typically true for encoder / decoder type networks + with lateral connection (e.g., FPN) for which feature maps need to match + dimension in the "bottom up" and "top down" paths. Set to 0 if no specific + input size divisibility is required. + """ + return 0 + + @property + def padding_constraints(self) -> Dict[str, int]: + """ + This property is a generalization of size_divisibility. Some backbones and training + recipes require specific padding constraints, such as enforcing divisibility by a specific + integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter + in :paper:vitdet). `padding_constraints` contains these optional items like: + { + "size_divisibility": int, + "square_size": int, + # Future options are possible + } + `size_divisibility` will read from here if presented and `square_size` indicates the + square padding size if `square_size` > 0. + + TODO: use type of Dict[str, int] to avoid torchscipt issues. The type of padding_constraints + could be generalized as TypedDict (Python 3.8+) to support more types in the future. + """ + return {} + + def output_shape(self): + """ + Returns: + dict[str->ShapeSpec] + """ + # this is a backward-compatible default + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } diff --git a/modeling/backbone/utils.py b/modeling/backbone/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b89a4c3fbe079a77fd0cef947cf9ada787fc55d --- /dev/null +++ b/modeling/backbone/utils.py @@ -0,0 +1,186 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = [ + "window_partition", + "window_unpartition", + "add_decomposed_rel_pos", + "get_abs_pos", + "PatchEmbed", +] + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size, k_size, rel_pos): + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +def get_abs_pos(abs_pos, has_cls_token, hw): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + h, w = hw + if has_cls_token: + abs_pos = abs_pos[:, 1:] + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ) + + return new_abs_pos.permute(0, 2, 3, 1) + else: + return abs_pos.reshape(1, h, w, -1) + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768 + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x): + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/modeling/backbone/vit.py b/modeling/backbone/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f98efd3dbf386919afe652984d9a2b9f89a84ab5 --- /dev/null +++ b/modeling/backbone/vit.py @@ -0,0 +1,404 @@ +import logging +import math +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn as nn +from torch.nn import functional as F +from detectron2.layers import CNNBlockBase, Conv2d, get_norm +from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous +from fairscale.nn.checkpoint import checkpoint_wrapper +from timm.models.layers import DropPath, Mlp, trunc_normal_ +from .backbone import Backbone +from .utils import ( + PatchEmbed, + add_decomposed_rel_pos, + get_abs_pos, + window_partition, + window_unpartition, +) + +logger = logging.getLogger(__name__) + + +__all__ = ["ViT"] + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + if not rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x): + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class ResBottleneckBlock(CNNBlockBase): + """ + The standard bottleneck residual block without the last activation layer. + It contains 3 conv layers with kernels 1x1, 3x3, 1x1. + """ + + def __init__( + self, + in_channels, + out_channels, + bottleneck_channels, + norm="LN", + act_layer=nn.GELU, + conv_kernels=3, + conv_paddings=1, + ): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + act_layer (callable): activation for all conv layers. + """ + super().__init__(in_channels, out_channels, 1) + + self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False) + self.norm1 = get_norm(norm, bottleneck_channels) + self.act1 = act_layer() + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + conv_kernels, + padding=conv_paddings, + bias=False, + ) + self.norm2 = get_norm(norm, bottleneck_channels) + self.act2 = act_layer() + + self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False) + self.norm3 = get_norm(norm, out_channels) + + for layer in [self.conv1, self.conv2, self.conv3]: + weight_init.c2_msra_fill(layer) + for layer in [self.norm1, self.norm2]: + layer.weight.data.fill_(1.0) + layer.bias.data.zero_() + # zero init last norm layer. + self.norm3.weight.data.zero_() + self.norm3.bias.data.zero_() + + def forward(self, x): + out = x + for layer in self.children(): + out = layer(out) + + out = x + out + return out + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + use_cc_attn = False, + use_residual_block=False, + use_convnext_block=False, + input_size=None, + res_conv_kernel_size=3, + res_conv_padding=1, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then not + use window attention. + use_residual_block (bool): If True, use a residual block after the MLP block. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) + + self.window_size = window_size + + self.use_residual_block = use_residual_block + if use_residual_block: + # Use a residual block with bottleneck channel as dim // 2 + self.residual = ResBottleneckBlock( + in_channels=dim, + out_channels=dim, + bottleneck_channels=dim // 2, + norm="LN", + act_layer=act_layer, + conv_kernels=res_conv_kernel_size, + conv_paddings=res_conv_padding, + ) + self.use_convnext_block = use_convnext_block + if use_convnext_block: + self.convnext = ConvNextBlock(dim = dim) + + if use_cc_attn: + self.attn = CrissCrossAttention(dim) + + + def forward(self, x): + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + if self.use_residual_block: + x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + if self.use_convnext_block: + x = self.convnext(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + return x + + +class ViT(Backbone): + """ + This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. + "Exploring Plain Vision Transformer Backbones for Object Detection", + https://arxiv.org/abs/2203.16527 + """ + + def __init__( + self, + img_size=1024, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_abs_pos=True, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + window_block_indexes=(), + residual_block_indexes=(), + use_act_checkpoint=False, + pretrain_img_size=224, + pretrain_use_cls_token=True, + out_feature="last_feat", + res_conv_kernel_size=3, + res_conv_padding=1, + ): + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + window_block_indexes (list): Indexes for blocks using window attention. + residual_block_indexes (list): Indexes for blocks using conv propagation. + use_act_checkpoint (bool): If True, use activation checkpointing. + pretrain_img_size (int): input image size for pretraining models. + pretrain_use_cls_token (bool): If True, pretrainig models use class token. + out_feature (str): name of the feature from the last block. + """ + super().__init__() + self.pretrain_use_cls_token = pretrain_use_cls_token + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size) + num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim)) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i in window_block_indexes else 0, + use_residual_block=i in residual_block_indexes, + input_size=(img_size // patch_size, img_size // patch_size), + res_conv_kernel_size=res_conv_kernel_size, + res_conv_padding=res_conv_padding, + ) + if use_act_checkpoint: + block = checkpoint_wrapper(block) + self.blocks.append(block) + + self._out_feature_channels = {out_feature: embed_dim} + self._out_feature_strides = {out_feature: patch_size} + self._out_features = [out_feature] + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + get_abs_pos( + self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2]) + ) + + for blk in self.blocks: + x = blk(x) + + outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)} + + return outputs['last_feat'] \ No newline at end of file diff --git a/modeling/criterion/__init__.py b/modeling/criterion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f09984281bdc1a8338f5a403e78f91bb8b3b172e --- /dev/null +++ b/modeling/criterion/__init__.py @@ -0,0 +1 @@ +from .matting_criterion import MattingCriterion \ No newline at end of file diff --git a/modeling/criterion/__pycache__/__init__.cpython-38.pyc b/modeling/criterion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef71b1ae2393b3dc9f6e03e7a8467b8854fd7078 Binary files /dev/null and b/modeling/criterion/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/criterion/__pycache__/matting_criterion.cpython-38.pyc b/modeling/criterion/__pycache__/matting_criterion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a624eda9f841b05af01fd3b91242718b5bcc6f0 Binary files /dev/null and b/modeling/criterion/__pycache__/matting_criterion.cpython-38.pyc differ diff --git a/modeling/criterion/matting_criterion.py b/modeling/criterion/matting_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..8f47a72e82257132b5f2100134bfc0b7696c2dbf --- /dev/null +++ b/modeling/criterion/matting_criterion.py @@ -0,0 +1,271 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import defaultdict + + +class MattingCriterion(nn.Module): + def __init__( + self, + *, + losses, + image_size = 1024, + ): + super(MattingCriterion, self).__init__() + self.losses = losses + self.image_size = image_size + + def loss_gradient_penalty(self, sample_map, preds, targets): + + #sample_map for unknown area + if torch.sum(sample_map) == 0: + scale = 0 + else: + scale = sample_map.shape[0] * (self.image_size ** 2) / torch.sum(sample_map) + + #gradient in x + sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type()) + delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1) + delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1) + + #gradient in y + sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type()) + delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1) + delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1) + + #loss + loss = (F.l1_loss(delta_pred_x * sample_map, delta_gt_x * sample_map) * scale + \ + F.l1_loss(delta_pred_y * sample_map, delta_gt_y * sample_map) * scale + \ + 0.01 * torch.mean(torch.abs(delta_pred_x * sample_map)) * scale + \ + 0.01 * torch.mean(torch.abs(delta_pred_y * sample_map)) * scale) + + return dict(loss_gradient_penalty=loss) + + def loss_pha_laplacian(self, preds, targets): + loss = laplacian_loss(preds, targets) + return dict(loss_pha_laplacian=loss) + + def unknown_l1_loss(self, sample_map, preds, targets): + + if torch.sum(sample_map) == 0: + scale = 0 + else: + scale = sample_map.shape[0] * (self.image_size ** 2) / torch.sum(sample_map) + # scale = 1 + + loss = F.l1_loss(preds * sample_map, targets * sample_map) * scale + + return dict(unknown_l1_loss=loss) + + def known_l1_loss(self, sample_map, preds, targets): + new_sample_map = torch.zeros_like(sample_map) + new_sample_map[sample_map==0] = 1 + + if torch.sum(new_sample_map) == 0: + scale = 0 + else: + scale = new_sample_map.shape[0] * (self.image_size ** 2) / torch.sum(new_sample_map) + # scale = 1 + + loss = F.l1_loss(preds * new_sample_map, targets * new_sample_map) * scale + + return dict(known_l1_loss=loss) + + def get_loss(self, k, sample_map, preds, targets): + if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty': + losses = getattr(self, k)(sample_map, preds, targets) + else: + losses = getattr(self, k)(preds, targets) + assert len(list(losses.keys())) == 1 + return losses[list(losses.keys())[0]] + + def forward(self, sample_map, preds, targets, batch_weight=None): + losses = {i: torch.tensor(0.0, device=sample_map.device) for i in self.losses} + for k in self.losses: + if batch_weight is None: + losses[k] += self.get_loss(k, sample_map, preds, targets) + else: + for i, loss_weight in enumerate(batch_weight): + if loss_weight == -1.0 and k != 'known_l1_loss': + continue + else: + losses[k] += self.get_loss(k, sample_map[i: i + 1], preds[i: i + 1], targets[i: i + 1]) * abs(loss_weight) + return losses + + +#-----------------Laplacian Loss-------------------------# +def laplacian_loss(pred, true, max_levels=5): + kernel = gauss_kernel(device=pred.device, dtype=pred.dtype) + pred_pyramid = laplacian_pyramid(pred, kernel, max_levels) + true_pyramid = laplacian_pyramid(true, kernel, max_levels) + loss = 0 + for level in range(max_levels): + loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level]) + return loss / max_levels + +def laplacian_pyramid(img, kernel, max_levels): + current = img + pyramid = [] + for _ in range(max_levels): + current = crop_to_even_size(current) + down = downsample(current, kernel) + up = upsample(down, kernel) + diff = current - up + pyramid.append(diff) + current = down + return pyramid + +def gauss_kernel(device='cpu', dtype=torch.float32): + kernel = torch.tensor([[1, 4, 6, 4, 1], + [4, 16, 24, 16, 4], + [6, 24, 36, 24, 6], + [4, 16, 24, 16, 4], + [1, 4, 6, 4, 1]], device=device, dtype=dtype) + kernel /= 256 + kernel = kernel[None, None, :, :] + return kernel + +def gauss_convolution(img, kernel): + B, C, H, W = img.shape + img = img.reshape(B * C, 1, H, W) + img = F.pad(img, (2, 2, 2, 2), mode='reflect') + img = F.conv2d(img, kernel) + img = img.reshape(B, C, H, W) + return img + +def downsample(img, kernel): + img = gauss_convolution(img, kernel) + img = img[:, :, ::2, ::2] + return img + +def upsample(img, kernel): + B, C, H, W = img.shape + out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype) + out[:, :, ::2, ::2] = img * 4 + out = gauss_convolution(out, kernel) + return out + +def crop_to_even_size(img): + H, W = img.shape[2:] + H = H - H % 2 + W = W - W % 2 + return img[:, :, :H, :W] + +def normalized_focal_loss(pred, gt, gamma=2, class_num=3, norm=True, beta_detach=False, beta_sum_detach=False): + pred_logits = F.softmax(pred, dim=1) # [B, 3, H, W] + gt_one_hot = F.one_hot(gt, class_num).permute(0, 3, 1, 2) # [B, 3, H, W] + p = (pred_logits * gt_one_hot).sum(dim=1) # [B, H, W] + beta = (1 - p) ** gamma # [B, H, W] + beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) / (pred.shape[-1] * pred.shape[-2]) # [B, 1, 1] + + if beta_detach: + beta = beta.detach() + if beta_sum_detach: + beta_sum = beta_sum.detach() + + if norm: + loss = 1 / beta_sum * beta * (-torch.log(p)) + return torch.mean(loss) + else: + loss = beta * (-torch.log(p)) + return torch.mean(loss) + +class GHMC(nn.Module): + def __init__(self, bins=10, momentum=0.75, loss_weight=1.0, device='cuda', norm=False): + super(GHMC, self).__init__() + self.bins = bins + self.momentum = momentum + self.edges = torch.arange(bins + 1).float().cuda() / bins + self.edges[-1] += 1e-6 + if momentum > 0: + self.acc_sum = torch.zeros(bins).cuda() + self.loss_weight = loss_weight + self.device = device + self.norm = norm + + def forward(self, pred, target, *args, **kwargs): + """Calculate the GHM-C loss. + Args: + pred (float tensor of size [batch_num, class_num]): + The direct prediction of classification fc layer. + target (float tensor of size [batch_num, class_num]): + Binary class target for each sample. + label_weight (float tensor of size [batch_num, class_num]): + the value is 1 if the sample is valid and 0 if ignored. + Returns: + The gradient harmonized loss. + """ + + # the target should be binary class label + # if pred.dim() != target.dim(): + # target, label_weight = _expand_binary_labels( + # target, label_weight, pred.size(-1)) + # target, label_weight = target.float(), label_weight.float() + # pdb.set_trace() + + # pred: [B, C, H, W], target: [B, H, W] + pred = pred.permute(0, 2, 3, 1).reshape(-1, 3) # [B x H x W, C] + target = target.reshape(-1) # [B x H x W] + # self.acc_sum = self.acc_sum.type(pred.dtype) + + edges = self.edges + mmt = self.momentum + weights = torch.zeros((target.shape),dtype=pred.dtype).to(self.device) + + # gradient length + #g = 1 - torch.index_select(F.softmax(pred,dim=1).detach(), dim=0, index=target) + g = 1 - torch.gather(F.softmax(pred,dim=1).detach(),dim=1,index=target.unsqueeze(1)) + #g = torch.abs(pred.softmax(2).detach() - target) + + tot = 1.0 + n = 0 # n valid bins + for i in range(self.bins): + inds = (g >= edges[i]) & (g < edges[i+1]) + num_in_bin = inds.sum().item() + if num_in_bin > 0: + idx = torch.nonzero(inds)[:, 0] + if mmt > 0: + self.acc_sum[i] = mmt * self.acc_sum[i] \ + + (1 - mmt) * num_in_bin + # pdb.set_trace()#scatter_ index_put_ + #BB=torch.nonzero(inds) + _weight_idx = tot / self.acc_sum[i] + weights = weights.to(dtype=_weight_idx.dtype) + weights[idx] = _weight_idx + # weights.scatter_(0, torch.nonzero(inds)[:,0], tot / self.acc_sum[i]) + # # weights.index_put_(inds, tot / self.acc_sum[i]) + # weights[inds] = tot / self.acc_sum[i] # * torch.ones((len(inds))) + else: + weights[idx] = tot / num_in_bin + n += 1 + if n > 0: + weights = weights / n + + # pdb.set_trace() + # loss = (weights * F.cross_entropy(pred, target, reduction='none')).sum() / tot / pred.shape[0] + if self.norm: + weights = weights / torch.sum(weights).detach() + + loss = - ((weights.unsqueeze(1) * torch.gather(F.log_softmax(pred, dim=1), dim=1, index=target.unsqueeze(1))).sum() ) # / pred.shape[0] + + # loss3= F.cross_entropy(pred, target, reduction='mean') + # loss4 = - ((torch.gather(F.log_softmax(pred, dim=1), dim=1, index=target.unsqueeze(1))).sum() / pred.shape[0]) + + # pro = F.softmax(logits, dim=1) + # + # label_onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1) + # with torch.no_grad(): + # weight_matrix = (1 - pro) ** self.gamma + # # pdb.set_trace() + # fl = - (weight_matrix * (label_onehot * (pro + self.eps).log())).sum() / pro.shape[0] + + return loss + +if __name__ == '__main__': + pred = torch.randn(2, 3, 1024, 1024) + gt =torch.argmax(torch.randn(2, 3, 1024, 1024), dim=1) + loss = normalized_focal_loss(pred, gt) + print(loss) + + + diff --git a/modeling/decoder/__init__.py b/modeling/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd2a43e5205320ded836374f6c823a90714c30c --- /dev/null +++ b/modeling/decoder/__init__.py @@ -0,0 +1 @@ +from .detail_capture import Detail_Capture, Ori_Detail_Capture \ No newline at end of file diff --git a/modeling/decoder/__pycache__/__init__.cpython-38.pyc b/modeling/decoder/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3769ef0cda17fb5f15cef2f9a9d3deb0799ac15 Binary files /dev/null and b/modeling/decoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/decoder/__pycache__/detail_capture.cpython-38.pyc b/modeling/decoder/__pycache__/detail_capture.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd1eb6bfc6209e706f0a4a9a24c627ba9ad8ea70 Binary files /dev/null and b/modeling/decoder/__pycache__/detail_capture.cpython-38.pyc differ diff --git a/modeling/decoder/__pycache__/unet_detail_capture.cpython-38.pyc b/modeling/decoder/__pycache__/unet_detail_capture.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..691371d22ddc50efe3f802294ec7373379b4ad3c Binary files /dev/null and b/modeling/decoder/__pycache__/unet_detail_capture.cpython-38.pyc differ diff --git a/modeling/decoder/detail_capture.py b/modeling/decoder/detail_capture.py new file mode 100644 index 0000000000000000000000000000000000000000..fd8b1d68433473a5d787de3ba53efa35ff9bfbcc --- /dev/null +++ b/modeling/decoder/detail_capture.py @@ -0,0 +1,185 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class Basic_Conv3x3(nn.Module): + """ + Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers. + """ + def __init__( + self, + in_chans, + out_chans, + stride=2, + padding=1, + ): + super().__init__() + self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False) + self.bn = nn.BatchNorm2d(out_chans) + self.relu = nn.ReLU(True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + + return x + +class ConvStream(nn.Module): + """ + Simple ConvStream containing a series of basic conv3x3 layers to extract detail features. + """ + def __init__( + self, + in_chans = 4, + out_chans = [48, 96, 192], + ): + super().__init__() + self.convs = nn.ModuleList() + + self.conv_chans = out_chans.copy() + self.conv_chans.insert(0, in_chans) + + for i in range(len(self.conv_chans)-1): + in_chan_ = self.conv_chans[i] + out_chan_ = self.conv_chans[i+1] + self.convs.append( + Basic_Conv3x3(in_chan_, out_chan_) + ) + + def forward(self, x): + out_dict = {'D0': x} + for i in range(len(self.convs)): + x = self.convs[i](x) + name_ = 'D'+str(i+1) + out_dict[name_] = x + + return out_dict + +class Fusion_Block(nn.Module): + """ + Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer. + """ + def __init__( + self, + in_chans, + out_chans, + ): + super().__init__() + self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1) + + def forward(self, x, D): + F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + out = torch.cat([D, F_up], dim=1) + out = self.conv(out) + + return out + +class Matting_Head(nn.Module): + """ + Simple Matting Head, containing only conv3x3 and conv1x1 layers. + """ + def __init__( + self, + in_chans = 32, + mid_chans = 16, + ): + super().__init__() + self.matting_convs = nn.Sequential( + nn.Conv2d(in_chans, mid_chans, 3, 1, 1), + nn.BatchNorm2d(mid_chans), + nn.ReLU(True), + nn.Conv2d(mid_chans, 1, 1, 1, 0) + ) + + def forward(self, x): + x = self.matting_convs(x) + + return x + +class Detail_Capture(nn.Module): + """ + Simple and Lightweight Detail Capture Module for ViT Matting. + """ + def __init__( + self, + in_chans = [384, 1], + img_chans=4, + convstream_out = [48, 96, 192], + fusion_out = [256, 128, 64, 32], + ): + super().__init__() + assert len(fusion_out) == len(convstream_out) + 1 + + self.convstream = ConvStream(in_chans=img_chans, out_chans=convstream_out) + self.conv_chans = self.convstream.conv_chans # [4, 48, 96, 192] + + self.fusion_blks = nn.ModuleList() + self.fus_channs = fusion_out.copy() + self.fus_channs.insert(0, in_chans[0]) # [384, 256, 128, 64, 32] + for i in range(len(self.fus_channs)-1): + in_channels = self.fus_channs[i] + self.conv_chans[-(i+1)] if i != 2 else in_chans[1] + self.conv_chans[-(i+1)] # [256 + 192 = 448, 256 + 96 = 352, 128 + 48 = 176, 64 + 4 = 68] + out_channels = self.fus_channs[i+1] # [256, 128, 64, 32] + self.fusion_blks.append( + Fusion_Block( + in_chans = in_channels, + out_chans = out_channels, + ) + ) + + self.matting_head = Matting_Head( # 32 --> 1 + in_chans = fusion_out[-1], + ) + + def forward(self, features, images): + detail_features = self.convstream(images) # [1, 4, 672, 992] --> D0: [1, 4, 672, 992], D1: [1, 48, 336, 496], D2: [1, 96, 168, 248], D3: [1, 192, 84, 124] + for i in range(len(self.fusion_blks)): # D3 + d_name_ = 'D'+str(len(self.fusion_blks)-i-1) + features = self.fusion_blks[i](features, detail_features[d_name_]) + + phas = torch.sigmoid(self.matting_head(features)) + + return {'phas': phas} + + +class Ori_Detail_Capture(nn.Module): + """ + Simple and Lightweight Detail Capture Module for ViT Matting. + """ + def __init__( + self, + in_chans = 384, + img_chans=4, + convstream_out = [48, 96, 192], + fusion_out = [256, 128, 64, 32], + ): + super().__init__() + assert len(fusion_out) == len(convstream_out) + 1 + + self.convstream = ConvStream(in_chans = img_chans) + self.conv_chans = self.convstream.conv_chans + + self.fusion_blks = nn.ModuleList() + self.fus_channs = fusion_out.copy() + self.fus_channs.insert(0, in_chans) + for i in range(len(self.fus_channs)-1): + self.fusion_blks.append( + Fusion_Block( + in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)], + out_chans = self.fus_channs[i+1], + ) + ) + + self.matting_head = Matting_Head( + in_chans = fusion_out[-1], + ) + + def forward(self, features, images): + detail_features = self.convstream(images) + for i in range(len(self.fusion_blks)): + d_name_ = 'D'+str(len(self.fusion_blks)-i-1) + features = self.fusion_blks[i](features, detail_features[d_name_]) + + phas = torch.sigmoid(self.matting_head(features)) + + return {'phas': phas} diff --git a/modeling/decoder/unet_detail_capture.py b/modeling/decoder/unet_detail_capture.py new file mode 100644 index 0000000000000000000000000000000000000000..79d66ea3f8cd074881c74e8e5452da1e67de64d0 --- /dev/null +++ b/modeling/decoder/unet_detail_capture.py @@ -0,0 +1,429 @@ +import cv2 +import torch +from torch import nn +from torch.nn import functional as F +# from nnMorpho.binary_operators import erosion +from detectron2.layers.batch_norm import NaiveSyncBatchNorm + + +class GenTrimapTorch(object): + def __init__(self, max_kernal=200): + self.max_kernal = max_kernal + self.erosion_kernels = [None] + [torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))).float().cuda() for size in range(1, self.max_kernal)] + + def __call__(self, mask, kernel_size): + + fg_width = kernel_size + bg_width = kernel_size + + fg_mask = mask + bg_mask = 1 - mask + + fg_mask = erosion(fg_mask, self.erosion_kernels[fg_width], border='a') + bg_mask = erosion(bg_mask, self.erosion_kernels[bg_width], border='a') + + trimap = torch.ones_like(mask) * 0.5 + trimap[fg_mask == 1] = 1.0 + trimap[bg_mask == 1] = 0.0 + + return trimap + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class BasicDownBlock(nn.Module): + def __init__(self, in_channel, out_channel, res = True, norm=LayerNorm2d, block_num=1, kernel_size=3): + super().__init__() + + self.res = res + self.basic_layer = nn.ModuleList() + for i in range(block_num): + if i == 0: + basic_layer_in_ch = in_channel + stride = 2 + else: + basic_layer_in_ch = out_channel + stride = 1 + self.basic_layer.append(nn.GELU()) + self.basic_layer.append(nn.Sequential( + nn.Conv2d(basic_layer_in_ch, out_channel, kernel_size, stride, kernel_size // 2), + norm(out_channel), + nn.GELU(), + nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2), + norm(out_channel), + )) + self.act = nn.GELU() + + if self.res: + self.res_layer = nn.Conv2d(in_channel, out_channel, kernel_size, 2, kernel_size // 2) + + def forward(self, x): + + if self.res: + identity = self.res_layer(x) + else: + identity = F.interpolate(x, size=(out.shape[-2], out.shape[-1]), mode='bilinear', align_corners=False) + + out = x + for layer in self.basic_layer: + out = layer(out) + + out = out + identity + out = self.act(out) + + return out + + +class BasicUpBlock(nn.Module): + + def __init__( self, in_channel, out_channel, res = True, skip_connect = 'concat', norm=LayerNorm2d, block_num=1, kernel_size=3): + super().__init__() + assert skip_connect in {'sum', 'concat'} + + self.res = res + self.skip_connect = skip_connect + self.basic_layer = nn.ModuleList() + for i in range(block_num): + if i == 0: + basic_layer_in_ch = in_channel + first_conv = nn.ConvTranspose2d(basic_layer_in_ch, out_channel, 2, 2) + else: + basic_layer_in_ch = out_channel + first_conv = nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2) + self.basic_layer.append(nn.GELU()) + self.basic_layer.append(nn.Sequential( + first_conv, + norm(out_channel), + nn.GELU(), + nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2), + norm(out_channel), + )) + self.act = nn.GELU() + + if self.res: + self.res_layer = nn.Conv2d(in_channel, out_channel, kernel_size, 1, kernel_size // 2) + + + def forward(self, x, skip_feat, concat_feat=None): + + if self.skip_connect == 'sum': + x = x + skip_feat + else: + x = torch.concat((x, skip_feat), dim=1) + + if concat_feat is not None: + x = torch.concat((x, concat_feat), dim=1) + + out = x + for layer in self.basic_layer: + out = layer(out) + # out = self.basic_layer(x) + + identity = F.interpolate(x, size=(out.shape[-2], out.shape[-1]), mode='bilinear', align_corners=False) + if self.res: + identity = self.res_layer(identity) + + out = out + identity + out = self.act(out) + + return out + + + +class DetailUNet(nn.Module): + def __init__( + self, + img_feat_in = 4, + vit_early_feat_in = 768, + matting_feat_in = 5, + downsample_in_out = [(4, 32), (32, 64), (64, 128), (128, 256)], + upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)], + matting_head_in = 16, + skip_connect = 'sum', + norm_type = 'LN', + ): + super().__init__() + + assert len(downsample_in_out) == len(upsample_in_out) + downsample_in_out[0] = (img_feat_in, downsample_in_out[0][1]) + + assert norm_type in {'BN', 'LN', 'SyncBN'} + if norm_type == 'BN': + self.norm = torch.nn.BatchNorm2d + elif norm_type == 'SyncBN': + self.norm = NaiveSyncBatchNorm + else: + self.norm = LayerNorm2d + + self.down_blks = nn.ModuleList() + for in_ch, out_ch in downsample_in_out: + self.down_blks.append( + BasicDownBlock(in_ch, out_ch, norm=self.norm) + ) + + self.mid_layer = nn.Sequential( + nn.Conv2d(vit_early_feat_in, downsample_in_out[-1][1], 1, 1), + self.norm(downsample_in_out[-1][1]), + nn.GELU(), + ) + + self.up_blks = nn.ModuleList() + for i, (in_ch, out_ch) in enumerate(upsample_in_out): + if i == 2: + in_ch += matting_feat_in + self.up_blks.append( + BasicUpBlock(in_ch, out_ch, skip_connect=skip_connect, norm=self.norm) + ) + + self.matting_head = nn.Conv2d(matting_head_in, 1, 3, 1, 1) + + + def forward(self, x, vit_early_feat, matting_feat, return_alpha_logits=False): + details = [] + dfeatures = x + + for i in range(len(self.down_blks)): + dfeatures = self.down_blks[i](dfeatures) + details.append(dfeatures) + + out = self.mid_layer(vit_early_feat) + for i in range(len(self.up_blks)): + if i == 2: + out = self.up_blks[i](out, details[-i - 1], matting_feat) + else: + out = self.up_blks[i](out, details[-i - 1]) + alpha = self.matting_head(out) + if return_alpha_logits: + return alpha, out + else: + return alpha + + +class MattingDetailDecoder(nn.Module): + def __init__( + self, + img_feat_in = 4, + vit_intern_feat_in = 1024, + vit_intern_feat_index = [0, 1, 2, 3], + downsample_in_out = [(4, 32), (32, 64), (64, 128), (128, 256)], + upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)], + matting_head_in = 16, + skip_connect = 'sum', + norm_type = 'BN', + norm_mask_logits = 6.5, + with_trimap = False, + min_kernel_size = 20, + kernel_div = 10, + concat_gen_trimap = False, + wo_hq_features = False, + block_num = 1, + wo_big_kernel = False, + sam2_multi_scale_feates = False, + ): + super().__init__() + + assert len(downsample_in_out) == len(upsample_in_out) + assert skip_connect in {'sum', 'concat'} + downsample_in_out[0] = (img_feat_in, downsample_in_out[0][1]) + + self.vit_intern_feat_in = vit_intern_feat_in + self.vit_intern_feat_index = vit_intern_feat_index + self.norm_mask_logits = norm_mask_logits + self.with_trimap = with_trimap + self.min_kernel_size = min_kernel_size + self.kernel_div = kernel_div + self.concat_gen_trimap = concat_gen_trimap + self.wo_hq_features = wo_hq_features + self.block_num = block_num + self.wo_big_kernel = wo_big_kernel + self.sam2_multi_scale_feates = sam2_multi_scale_feates + if self.sam2_multi_scale_feates: + assert downsample_in_out[0][0] == 6 + downsample_in_out = [(4, 32), (32, 64), (64 + 32, 128), (128 + 64, 256)] + upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)] + + if self.with_trimap and not self.concat_gen_trimap: + self.gen_trimap = GenTrimapTorch() + assert norm_type in {'BN', 'LN', 'SyncBN'} + if norm_type == 'BN': + self.norm = torch.nn.BatchNorm2d + elif norm_type == 'SyncBN': + self.norm = NaiveSyncBatchNorm + else: + self.norm = LayerNorm2d + + if self.block_num >= 2 and not self.wo_big_kernel: + self.big_kernel_process = nn.Sequential( + nn.Conv2d(img_feat_in, 16, kernel_size=13, stride=1, padding=6), + self.norm(16), + nn.GELU(), + nn.Conv2d(16, 32, kernel_size=13, stride=1, padding=6), + self.norm(32), + nn.GELU(), + ) + downsample_in_out[0] = (32, downsample_in_out[0][1]) + + if not self.sam2_multi_scale_feates: + self.vit_feat_proj = nn.ModuleDict() + for idx in self.vit_intern_feat_index: + self.vit_feat_proj[str(idx)] = nn.Conv2d(self.vit_intern_feat_in, self.vit_intern_feat_in // len(self.vit_intern_feat_index), 1, 1) + self.vit_feat_aggregation = nn.Sequential( + nn.Conv2d(self.vit_intern_feat_in // len(self.vit_intern_feat_index) * len(self.vit_intern_feat_index), downsample_in_out[-1][1], 3, 1, 1), + self.norm(downsample_in_out[-1][1]), + nn.GELU(), + ) + + self.down_blks = nn.ModuleList() + for in_ch, out_ch in downsample_in_out: + self.down_blks.append( + BasicDownBlock(in_ch, out_ch, norm=self.norm, block_num=self.block_num, kernel_size=5 if self.block_num >= 2 else 3) + ) + + if self.sam2_multi_scale_feates: + self.mid_layer = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(32, 32, 1, 1), + self.norm(32), + nn.GELU(), + ), + nn.Sequential( + nn.Conv2d(64, 64, 1, 1), + self.norm(64), + nn.GELU(), + ), + nn.Sequential( + nn.Conv2d(256, 256, 1, 1), + self.norm(256), + nn.GELU(), + ), + nn.Sequential( + nn.Conv2d(512, 256, 3, 1, 1), + self.norm(256), + nn.GELU(), + ), + ]) + else: + self.mid_layer = nn.Sequential( + nn.Conv2d(downsample_in_out[-1][1] * 2, downsample_in_out[-1][1], 1, 1), + self.norm(downsample_in_out[-1][1]), + nn.GELU(), + ) + + self.up_blks = nn.ModuleList() + for _, (in_ch, out_ch) in enumerate(upsample_in_out): + if skip_connect == 'concat': + self.up_blks.append(BasicUpBlock(in_ch * 2, out_ch, skip_connect=skip_connect, norm=self.norm, block_num=self.block_num)) + else: + self.up_blks.append(BasicUpBlock(in_ch, out_ch, skip_connect=skip_connect, norm=self.norm, block_num=self.block_num)) + + self.matting_head = nn.Conv2d(matting_head_in, 1, 3, 1, 1) + + if self.norm_mask_logits == 'BN': + self.logits_norm = self.norm(1) + + + def preprocess_inputs(self, images, hq_features, pred_trimap): + + if self.wo_hq_features: + return images + + if isinstance(self.norm_mask_logits, float): + norm_hq_features = hq_features / self.norm_mask_logits + elif self.norm_mask_logits == 'BN': + norm_hq_features = self.logits_norm(hq_features) + elif self.norm_mask_logits == 'Sigmoid': + if hq_features.shape[1] == 1: + norm_hq_features = torch.sigmoid(hq_features) + else: + norm_hq_features = torch.softmax(hq_features, dim=1) + elif self.norm_mask_logits: + norm_hq_features = hq_features / torch.std(hq_features, dim=(1, 2, 3), keepdim=True) + else: + norm_hq_features = hq_features + + if self.concat_gen_trimap: + pred_trimap = F.interpolate(pred_trimap, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False) + pred_trimap = torch.argmax(pred_trimap, dim=1, keepdim=True).float() / 2.0 + norm_hq_features = torch.concat((norm_hq_features, pred_trimap.detach()), dim=1) + elif self.with_trimap: + mask = (norm_hq_features > 0).float() + for i_batch in range(images.shape[0]): + mask_area = torch.sum(mask[i_batch]) + kernel_size = max(self.min_kernel_size, int((mask_area ** 0.5) / self.kernel_div)) + kernel_size = min(kernel_size, self.gen_trimap.max_kernal - 1) + mask[i_batch, 0] = self.gen_trimap(mask[i_batch, 0], kernel_size=kernel_size) + trimaps = mask + norm_hq_features = torch.concat((norm_hq_features, trimaps), dim=1) + + conditional_images = torch.concatenate((images, norm_hq_features), dim=1) + return conditional_images + + def forward(self, images, hq_features, vit_intern_feat, return_alpha_logits=False, pred_trimap=None): + + condition_input = self.preprocess_inputs(images, hq_features, pred_trimap) + + if not self.sam2_multi_scale_feates: + # aggregate 4 vit_intern_feat + # assert len(vit_intern_feat) == self.vit_intern_feat_num + vit_feats = [] + for idx in self.vit_intern_feat_index: + vit_feats.append(self.vit_feat_proj[str(idx)](vit_intern_feat[idx].permute(0, 3, 1, 2))) + vit_feats = torch.concat(vit_feats, dim=1) + vit_aggregation_feats = self.vit_feat_aggregation(vit_feats) + + details = [] + dfeatures = condition_input + + if hasattr(self, 'big_kernel_process'): + dfeatures = self.big_kernel_process(dfeatures) + + for i in range(len(self.down_blks)): + if self.sam2_multi_scale_feates: + if i == 2: + dfeatures = torch.concat((dfeatures, self.mid_layer[0](vit_intern_feat['high_res_feats'][0])), dim=1) + elif i == 3: + dfeatures = torch.concat((dfeatures, self.mid_layer[1](vit_intern_feat['high_res_feats'][1])), dim=1) + dfeatures = self.down_blks[i](dfeatures) + details.append(dfeatures) + + if self.sam2_multi_scale_feates: + out = torch.concat((details[-1], self.mid_layer[2](vit_intern_feat['image_embed'])), dim=1) + out = self.mid_layer[3](out) + else: + out = self.mid_layer(torch.concat((details[-1], vit_aggregation_feats), dim=1)) + for i in range(len(self.up_blks)): + out = self.up_blks[i](out, details[-i - 1]) + alpha = torch.sigmoid(self.matting_head(out)) + if return_alpha_logits: + return alpha, out + else: + return alpha + + + +if __name__ == '__main__': + + from engine.mattingtrainer import parameter_count_table + + model = MattingDetailDecoder(img_feat_in = 5, vit_intern_feat_index=[0]) + x = torch.randn((2, 5, 1024, 1024)) + hq_features = torch.randn((2, 1, 1024, 1024)) + vit_feat = [torch.randn((2, 64, 64, 1024)) for _ in range(4)] + + out = model(x, hq_features, vit_feat) + print(out.shape) + + print("Trainable parameters: \n" + parameter_count_table(model, trainable_only=True, max_depth=5)) diff --git a/modeling/meta_arch/__init__.py b/modeling/meta_arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87aa29eb1547d347cd6e41fd1d831c318908be47 --- /dev/null +++ b/modeling/meta_arch/__init__.py @@ -0,0 +1 @@ +from .sam_hq_matting import SamHqMatte \ No newline at end of file diff --git a/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc b/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc7fd1eb6507a9c9f894d49888d8353b22ca0c5a Binary files /dev/null and b/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/meta_arch/__pycache__/sam_hq_matting.cpython-38.pyc b/modeling/meta_arch/__pycache__/sam_hq_matting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..155e8015bd75aa6a8740e168b880463db3b9fb5b Binary files /dev/null and b/modeling/meta_arch/__pycache__/sam_hq_matting.cpython-38.pyc differ diff --git a/modeling/meta_arch/sam_hq_matting.py b/modeling/meta_arch/sam_hq_matting.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a7bca52d5c3f5ab16201dd9e1453f8650b96f0 --- /dev/null +++ b/modeling/meta_arch/sam_hq_matting.py @@ -0,0 +1,671 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import os +import numpy as np +from PIL import Image +from copy import deepcopy +from collections import defaultdict + +from detectron2.structures import ImageList +from detectron2.utils.comm import get_local_rank +from modeling.semantic_enhanced_matting.predictor import SamPredictor +from modeling.semantic_enhanced_matting.condition_conv import ConditionConv, ConditionEmbedding, ConditionAdd, BBoxEmbedInteract, BBoxInteract, BBoxInteractInOut +from modeling.semantic_enhanced_matting.modeling.image_encoder import PatchEmbed +from modeling.semantic_enhanced_matting.modeling.common import LayerNorm2d +from modeling.decoder.unet_detail_capture import MattingDetailDecoder +from modeling.semantic_enhanced_matting.feature_fusion import FeatureFusion +from sam2.sam2_image_predictor import SAM2ImagePredictor + +from modeling.semantic_enhanced_matting.modeling.mask_decoder_hq_matting import MaskDecoderHQMatting +from modeling.semantic_enhanced_matting.modeling import TwoWayTransformer + +from peft import LoraConfig, get_peft_model +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer + +from data.rand_augment import RandAugment +import random +import kornia.filters as kf + + +class SamHqMatte(nn.Module): + + target_length = 1024 + + def __init__( + self, + *, + sam_model, + hq_token_only, + hq_features_type, + matting_decoder, + criterion, + pixel_mean, + pixel_std, + multimask_output=False, + vis_period=None, + output_dir=None, + lora_rank = None, + lora_alpha = None, + lora_target_modules = ["qkv", "proj"], + lora_dropout = 0.1, + w_dora = False, + w_rslora = False, + lora_on_mask_decoder = False, + frozen_sam_hq_reg = None, + reg_margin = 0.85, + w_attention_mask = False, + alpha_reg_range = None, + alpha_reg_weight = 1.0, + coconut_pl = False, + coconut_pl_alpha = 1.0, + coconut_self_training = False, + eval_w_sam_hq_mask = False, + backbone_condition = False, + condition_wo_conv = False, + w_only_bbox_cond = False, + coconut_only_known_l1 = False, + backbone_bbox_prompt = None, + backbone_bbox_prompt_loc = [2, 3], + backbone_bbox_prompt_loss_weight = 1.0, + concat_gen_trimap = False, + multi_matting_decoder = None, + w_all_logits = False, + bbox_prompt_all_block = None, + matting_token = False, + test_w_hq_token = False, + sam_hq_token_reg = None, + feat_cross_attn_fusion = False, + trimap_loss_type = None, + reg_on_sam_logits = False, + reg_w_bce_loss = False, + complex_trimap_pred_layer = False, + matting_token_sup = None, + matting_token_sup_loss_weight = None, + sam2 = False, + ): + super(SamHqMatte, self).__init__() + + self.sam_model = sam_model + self.sam_predictor = SamPredictor(self.sam_model) if not sam2 else SAM2ImagePredictor(self.sam_model) # already in eval mode and no_grad + self.hq_token_only = hq_token_only + self.multimask_output = multimask_output + self.hq_features_type = hq_features_type + + self.matting_decoder = matting_decoder + + self.criterion = criterion + + self.register_buffer( + "pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False + ) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + assert ( + self.pixel_mean.shape == self.pixel_std.shape + ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!" + + self.vis_period = vis_period + if output_dir is not None and output_dir != '?': + self.output_dir = os.path.join(output_dir, 'vis_results') + os.makedirs(self.output_dir, exist_ok=True) + self.train_iter_index = 0 + + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha + self.lora_target_modules = lora_target_modules + self.lora_dropout = lora_dropout + self.w_dora = w_dora + self.w_rslora = w_rslora + self.lora_on_mask_decoder = lora_on_mask_decoder + self.frozen_sam_hq_reg = frozen_sam_hq_reg + self.reg_margin = reg_margin + self.w_attention_mask = w_attention_mask + self.alpha_reg_range = alpha_reg_range + self.alpha_reg_weight = alpha_reg_weight + self.coconut_pl = coconut_pl + self.coconut_pl_alpha = coconut_pl_alpha + self.coconut_self_training = coconut_self_training + self.eval_w_sam_hq_mask = eval_w_sam_hq_mask + self.backbone_condition = backbone_condition + self.condition_wo_conv = condition_wo_conv + self.w_only_bbox_cond = w_only_bbox_cond + self.coconut_only_known_l1 = coconut_only_known_l1 + self.backbone_bbox_prompt = backbone_bbox_prompt + self.backbone_bbox_prompt_loc = backbone_bbox_prompt_loc + self.backbone_bbox_prompt_loss_weight = backbone_bbox_prompt_loss_weight + self.concat_gen_trimap = concat_gen_trimap + self.multi_matting_decoder = multi_matting_decoder + self.w_all_logits = w_all_logits + self.bbox_prompt_all_block = bbox_prompt_all_block + self.matting_token = matting_token + self.test_w_hq_token = test_w_hq_token + self.sam_hq_token_reg = sam_hq_token_reg + self.feat_cross_attn_fusion = feat_cross_attn_fusion + self.trimap_loss_type = trimap_loss_type + self.reg_on_sam_logits = reg_on_sam_logits + self.reg_w_bce_loss = reg_w_bce_loss + self.complex_trimap_pred_layer = complex_trimap_pred_layer + self.matting_token_sup = matting_token_sup + self.sam2 = sam2 + assert self.matting_token_sup in {'alpha', 'trimap', None} + self.matting_token_sup_loss_weight = matting_token_sup_loss_weight + if self.matting_token_sup is not None: + assert self.backbone_bbox_prompt in {'bbox', None} + if self.frozen_sam_hq_reg is not None: + assert self.lora_rank is not None + if self.w_attention_mask: + self.attention_head = deepcopy(self.matting_decoder) + if self.coconut_self_training: + self.rand_aug = RandAugment(3,6) + self.warm_iter_coconut_self_training = 5000 + if self.backbone_condition: + assert self.lora_rank is not None + if self.backbone_bbox_prompt is not None: + assert self.lora_rank is not None + if self.w_all_logits: + self.sam_predictor.model.mask_decoder.w_all_logits = True + if self.bbox_prompt_all_block: + assert self.lora_rank is not None + if self.matting_token and not self.sam2: + self.sam_predictor.model.mask_decoder.hq_token_only = self.hq_token_only + + @property + def device(self): + return self.pixel_mean.device + + def init_lora(self, model=None): + if model is not None and self.lora_rank >= 1: + if self.lora_on_mask_decoder: + self.lora_target_modules += ["q_proj", "k_proj", "v_proj", "out_proj"] + modules_to_save = None + else: + modules_to_save = ['matting_decoder'] + + lora_config = LoraConfig( + r=self.lora_rank, + lora_alpha=self.lora_alpha, + use_rslora=self.w_rslora, + use_dora=self.w_dora, + init_lora_weights="gaussian", + target_modules=self.lora_target_modules, + lora_dropout=self.lora_dropout, + modules_to_save=modules_to_save + ) + model = get_peft_model(model, lora_config) + if self.lora_on_mask_decoder: + for n, p in model.matting_decoder.named_parameters(): + if n.split('modules_to_save.default.')[-1] in model.matting_decoder.trainable_params_str: + p.requires_grad = True + else: + for n, p in model.matting_decoder.named_parameters(): + if n.split('modules_to_save.default.')[-1] in model.matting_decoder.frozen_params_str: + p.requires_grad = False + return model + elif self.lora_rank >= 1: + lora_config = LoraConfig( + r=self.lora_rank, + lora_alpha=self.lora_alpha, + use_rslora=self.w_rslora, + use_dora=self.w_dora, + init_lora_weights="gaussian", + target_modules=self.lora_target_modules, + lora_dropout=self.lora_dropout, + ) + self.sam_predictor.model.image_encoder = get_peft_model(self.sam_predictor.model.image_encoder, lora_config) + + if self.sam2: + for n, p in self.sam_predictor.model.image_encoder.named_parameters(): + if 'bbox_mask' in n: + p.requires_grad = True + + if self.backbone_condition: + if self.w_only_bbox_cond: + self.condition_embedding = ConditionEmbedding(condition_num = 4, pos_embedding_dim = 160) + else: + self.condition_embedding = ConditionEmbedding(condition_num = 5, pos_embedding_dim = 128) + + if self.condition_wo_conv: + self.condition_conv = nn.ModuleList([ConditionAdd() for _ in range(4)]) + else: + self.condition_conv = nn.ModuleList([ConditionConv( + in_channels = self.sam_predictor.model.image_encoder.embed_dim, + out_channels = self.sam_predictor.model.image_encoder.embed_dim, + bottleneck_channels = 512 + ) for _ in range(4)]) + + if self.backbone_bbox_prompt is not None and not self.sam2: + self.condition_layer = nn.ModuleDict() + self.condition_layer['patch_embed'] = PatchEmbed( + kernel_size=(self.sam_predictor.model.image_encoder.patch_size, self.sam_predictor.model.image_encoder.patch_size), + stride=(self.sam_predictor.model.image_encoder.patch_size, self.sam_predictor.model.image_encoder.patch_size), + in_chans=4, + embed_dim=self.sam_predictor.model.image_encoder.embed_dim, + ) + if self.multi_matting_decoder is None: + if self.backbone_bbox_prompt in {'trimap', 'alpha_trimap'}: + transformer_dim = self.sam_predictor.model.image_encoder.embed_dim + for i in self.backbone_bbox_prompt_loc: + if self.complex_trimap_pred_layer: + self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 2, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 2), # 512 + nn.GELU(), + nn.Conv2d(transformer_dim // 2, transformer_dim // 4, kernel_size=3, stride=1, padding=1), + LayerNorm2d(transformer_dim // 4), # 256 + nn.GELU(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 8), # 128 + nn.GELU(), + nn.Conv2d(transformer_dim // 8, transformer_dim // 16, kernel_size=3, stride=1, padding=1), + LayerNorm2d(transformer_dim // 16), # 64 + nn.GELU(), + nn.Conv2d(transformer_dim // 16, 3, kernel_size=3, stride=1, padding=1), + ) + else: + self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(transformer_dim // 8, 3, kernel_size=1, stride=1), + ) + elif self.backbone_bbox_prompt == 'alpha': + transformer_dim = self.sam_predictor.model.image_encoder.embed_dim + for i in self.backbone_bbox_prompt_loc: + if self.complex_trimap_pred_layer: + self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 2, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 2), # 512 + nn.GELU(), + nn.Conv2d(transformer_dim // 2, transformer_dim // 4, kernel_size=3, stride=1, padding=1), + LayerNorm2d(transformer_dim // 4), # 256 + nn.GELU(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 8), # 128 + nn.GELU(), + nn.Conv2d(transformer_dim // 8, transformer_dim // 16, kernel_size=3, stride=1, padding=1), + LayerNorm2d(transformer_dim // 16), # 64 + nn.GELU(), + nn.Conv2d(transformer_dim // 16, 1, kernel_size=3, stride=1, padding=1), + nn.Sigmoid() + ) + else: + self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + nn.GELU(), + nn.Conv2d(transformer_dim // 8, 1, kernel_size=1, stride=1), + nn.Sigmoid() + ) + if self.bbox_prompt_all_block is not None: + if self.bbox_prompt_all_block == 'reuse_cross-self-attn': + self.condition_layer['prompt_layer'] = BBoxInteract( + position_point_embedding = deepcopy(self.sam_predictor.model.prompt_encoder.pe_layer), + point_weight = deepcopy(self.sam_predictor.model.prompt_encoder.point_embeddings) + ) + elif self.bbox_prompt_all_block == 'in-out-bbox_cross-self-attn': + self.condition_layer['prompt_layer'] = BBoxInteractInOut(downsample_rate = 2) + else: + embed_type, interact_type = self.bbox_prompt_all_block.split('_') + self.condition_layer['prompt_layer'] = BBoxEmbedInteract(embed_type, interact_type) + + if self.feat_cross_attn_fusion: + self.condition_layer['feature_fusion'] = FeatureFusion(in_channels=self.sam_predictor.model.image_encoder.embed_dim, attn_compression_ratio=8) + + def condition_bbox_and_instance_num(self): + self.sam_predictor.model.image_encoder.conv_necks = None + + def forward_samhq_and_matting_decoder(self, images, bbox, condition_proj=None, return_hq_token=False): + # get features from SAM image encoder + if self.sam2: + interm_features, sam2_logits, matting_logits, pred_trimap = self.forward_samhq(images, bbox, condition_proj) + sam2_logits = F.interpolate(sam2_logits, size=images.shape[-2:], mode='bilinear', align_corners=False) + matting_logits = F.interpolate(matting_logits, size=images.shape[-2:], mode='bilinear', align_corners=False) + sam_hq_matting_token = { + 'masks_hq': sam2_logits, + 'masks_matting': matting_logits + } + hq_features = matting_logits + low_res_masks = matting_logits + else: + if self.matting_token: + features, image_pe, sparse_embeddings, dense_embeddings, interm_features, sam_hq_matting_token, pred_trimap = self.forward_samhq(images, bbox, condition_proj) + if return_hq_token: + return sam_hq_matting_token['masks_hq'] + else: + if not self.training and self.test_w_hq_token: + low_res_masks, hq_features = sam_hq_matting_token['masks_hq'], sam_hq_matting_token['masks_hq'] + else: + low_res_masks, hq_features = sam_hq_matting_token['masks_matting'], sam_hq_matting_token['masks_matting'] + else: + features, image_pe, sparse_embeddings, dense_embeddings, interm_features, hq_features, sam_logits, low_res_masks, pred_trimap = self.forward_samhq(images, bbox, condition_proj) + if return_hq_token: + return hq_features + sam_hq_matting_token = {'masks_hq': hq_features, 'masks_sam': sam_logits} + + # get alpha from our proposed matting_decoder + if isinstance(self.matting_decoder, MattingDetailDecoder): + pred_alpha = self.matting_decoder( + images = images, + hq_features = hq_features, + vit_intern_feat = interm_features, + return_alpha_logits = (self.alpha_reg_range is not None), + pred_trimap = pred_trimap + ) + else: + pred_alpha = self.matting_decoder( + image_embeddings = features, # [B, 256, 64, 64] + image_pe = image_pe, + sparse_prompt_embeddings = sparse_embeddings, + dense_prompt_embeddings = dense_embeddings, + multimask_output = False, + interm_embeddings = interm_features, # [B, 256, 64, 64] + hq_features = hq_features, + images = images, + return_alpha_logits = (self.alpha_reg_range is not None), + pred_trimap = pred_trimap + ) + return low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token + + def forward(self, batched_inputs): # image: [1, 3, 643, 960]: 0.0~1.0, trimap: [1, 1, 643, 960]: 0.0~1.0 + + inputs = self.preprocess_inputs(batched_inputs) + images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition'] + + if self.backbone_condition: + condition_proj = self.condition_embedding(condition) + elif self.backbone_bbox_prompt is not None or self.bbox_prompt_all_block is not None: + condition_proj = bbox + else: + condition_proj = None + + low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token = self.forward_samhq_and_matting_decoder(images, bbox, condition_proj) + + assert not self.training + if self.eval_w_sam_hq_mask: + self.sam_predictor.model.image_encoder.disable_adapter_layers() + with torch.no_grad(): + ori_features, ori_interm_features = self.sam_predictor.model.image_encoder(images) + samhq_low_res_masks = self.forward_samhq_others(images, bbox, ori_features, ori_interm_features)[-1] + samhq_low_res_masks = F.interpolate(samhq_low_res_masks, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False) + self.sam_predictor.model.image_encoder.enable_adapter_layers() + + return pred_alpha, samhq_low_res_masks + else: + return pred_alpha + + def forward_samhq_image_encoder(self, images, condition_proj=None): + if self.sam2: + backbone_out = self.sam_predictor.model.forward_image([images, condition_proj]) + _, vision_feats, _, _ = self.sam_predictor.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.sam_predictor.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.sam_predictor.model.no_mem_embed + feats = [ + feat.permute(1, 2, 0).view(feat.shape[1], -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self.sam_predictor._bb_feat_sizes[::-1]) + ][::-1] + return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}, None, None + else: + if self.backbone_condition: + condition_layer = self.condition_conv + elif self.backbone_bbox_prompt: + condition_layer = self.condition_layer + else: + condition_layer = None + # [B, 3, 1024, 1024]: -2. ~ 2. --> [B, 256, 64, 64], 4 x [B, 64, 64, 768] + features, interm_features, pred_trimap = self.sam_predictor.model.image_encoder(images, condition_proj, condition_layer) + return features, interm_features, pred_trimap + + # @torch.no_grad() + def forward_samhq_others(self, images, bbox, features, interm_features): + if self.sam2: + sam2_logits, matting_logits = self.sam_predictor.predict_batch_boxes_and_features(bbox, features) + return features, sam2_logits, matting_logits + + image_pe = self.sam_predictor.model.prompt_encoder.get_dense_pe() + + cat_sparse_embeddings = [] + cat_dense_prompt_embeddings = [] + cat_hq_features = [] + cat_sam_logits = [] + cat_low_res_masks = [] + cat_sam_hq_matting_token = defaultdict(list) + + for idx in range(images.shape[0]): + # get hq_features from SAM_HQ mask decoder + + # Embed prompts + sparse_embeddings, dense_embeddings = self.sam_predictor.model.prompt_encoder( + points=None, + # boxes=bbox[idx: idx + 1], + boxes=bbox[idx], # [N, 4] + masks=None, + ) # [B, 2, 256], [B, 256, 64, 64] + + # Predict masks + if isinstance(self.sam_predictor.model.mask_decoder, MaskDecoderHQMatting): + sam_hq_matting_token = self.sam_predictor.model.mask_decoder( + image_embeddings = features[idx: idx + 1], + image_pe = image_pe, + sparse_prompt_embeddings = sparse_embeddings, + dense_prompt_embeddings = dense_embeddings, + multimask_output = self.multimask_output, + interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features], + ) + for key in sam_hq_matting_token.keys(): + cat_sam_hq_matting_token[key].append(sam_hq_matting_token[key]) + else: + low_res_masks, masks_sam, hq_features = self.sam_predictor.model.mask_decoder( + image_embeddings = features[idx: idx + 1], + image_pe = image_pe, + sparse_prompt_embeddings = sparse_embeddings, + dense_prompt_embeddings = dense_embeddings, + multimask_output = self.multimask_output, + hq_token_only = self.hq_token_only, + interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features], + return_hq_features_type = self.hq_features_type + ) + cat_hq_features.append(hq_features) + cat_sam_logits.append(masks_sam) + cat_low_res_masks.append(low_res_masks) + + cat_sparse_embeddings.append(sparse_embeddings) + cat_dense_prompt_embeddings.append(dense_embeddings) + + sparse_embeddings = torch.stack(cat_sparse_embeddings, dim=0) # [B, 1, 2, 256] + dense_embeddings = torch.stack(cat_dense_prompt_embeddings, dim=0) # [B, 1, 256, 64, 64] + + if self.matting_token: + for key in cat_sam_hq_matting_token.keys(): + cat_sam_hq_matting_token[key] = torch.cat(cat_sam_hq_matting_token[key], dim=0) + cat_sam_hq_matting_token[key] = F.interpolate(cat_sam_hq_matting_token[key], size=images.shape[-2:], mode='bilinear', align_corners=False) + sam_hq_matting_token = cat_sam_hq_matting_token + return features, image_pe, sparse_embeddings, dense_embeddings, interm_features, sam_hq_matting_token + else: + hq_features = torch.cat(cat_hq_features, dim=0) # [B, 1, 256, 256] + low_res_masks = torch.cat(cat_low_res_masks, dim=0) # [B, 1, 256, 256] + hq_features = F.interpolate(hq_features, size=images.shape[-2:], mode='bilinear', align_corners=False) # [B, 1, 256, 256] --> [B, 1, 1024, 1024] + sam_logits = torch.cat(cat_sam_logits, dim=0) + sam_logits = F.interpolate(sam_logits, size=images.shape[-2:], mode='bilinear', align_corners=False) # [B, 1, 256, 256] --> [B, 1, 1024, 1024] + return features, image_pe, sparse_embeddings, dense_embeddings, interm_features, hq_features, sam_logits, low_res_masks + + def forward_samhq(self, images, bbox, condition_proj=None): + if self.lora_rank is None: + with torch.no_grad(): + features, interm_features, pred_trimap = self.forward_samhq_image_encoder(images, condition_proj) + else: + features, interm_features, pred_trimap = self.forward_samhq_image_encoder(images, condition_proj) + + return self.forward_samhq_others(images, bbox, features, interm_features) + (pred_trimap, ) + + def get_frozen_sam_logits(self, images, bbox, mask_type='hq'): + + if self.sam2: + features, _, _ = self.forward_samhq_image_encoder(images) + sam2_logits = self.sam_predictor.predict_batch_boxes_and_features(bbox, features, wo_matting_token=True) + sam2_logits = F.interpolate(sam2_logits, size=images.shape[-2:], mode='bilinear', align_corners=False) + return sam2_logits + + assert mask_type in {'hq', 'sam'} + features, interm_features, _ = self.forward_samhq_image_encoder(images) + image_pe = self.sam_predictor.model.prompt_encoder.get_dense_pe() + + cat_logits = [] + for idx in range(images.shape[0]): + sparse_embeddings, dense_embeddings = self.sam_predictor.model.prompt_encoder(points=None, boxes=bbox[idx], masks=None) + + low_res_masks, masks_sam, hq_features = self.sam_predictor.model.frozen_mask_decoder( + image_embeddings = features[idx: idx + 1], + image_pe = image_pe, + sparse_prompt_embeddings = sparse_embeddings, + dense_prompt_embeddings = dense_embeddings, + multimask_output = self.multimask_output, + hq_token_only = self.hq_token_only, + interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features], + return_hq_features_type = self.hq_features_type + ) + if mask_type == 'hq': + cat_logits.append(hq_features) + else: + cat_logits.append(masks_sam) + + logits = torch.cat(cat_logits, dim=0) # [B, 1, 256, 256] + logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False) # [B, 1, 256, 256] --> [B, 1, 1024, 1024] + return logits + + def vis_training_results(self, **kwargs): + # images, bbox, trimap, low_res_masks, pred_alpha, alpha + self.train_iter_index += 1 + if self.train_iter_index % self.vis_period == 0: + batch_save_results = [] + save_path = os.path.join(self.output_dir, '{:06d}_rank{}.jpg'.format(self.train_iter_index, get_local_rank())) + + # [('images', (4, 3, 1024, 1024), -2.117904, 2.64), ('bbox', (4, 1, 4), 0.0, 1023.0), ('trimap', (4, 1, 1024, 1024), 0.0, 1.0), ('low_res_masks', (4, 1, 256, 256), -20.38, 10.15), ('pred_alpha', (4, 1, 1024, 1024), 0.1547, 0.791), ('alpha', (4, 1, 1024, 1024), 0.0, 1.0)] + for key in kwargs.keys(): + if key == 'bbox': + continue + # turn all tensor to [B, H, W, 3]: 0~255 np.int8 + if key == 'images': + kwargs[key] = kwargs[key] * self.pixel_std + self.pixel_mean + kwargs[key] = kwargs[key].permute(0, 2, 3, 1) * 255.0 + for i in range(kwargs['images'].shape[0]): + l, u, r, d = int(kwargs['bbox'][i, 0, 0].item()), int(kwargs['bbox'][i, 0, 1].item()), int(kwargs['bbox'][i, 0, 2].item()), int(kwargs['bbox'][i, 0, 3].item()) + red_line = torch.tensor([[255., 0., 0.]], device=kwargs[key].device, dtype=kwargs[key].dtype) + kwargs[key][i, u: d, l, :] = red_line + kwargs[key][i, u: d, r, :] = red_line + kwargs[key][i, u, l: r, :] = red_line + kwargs[key][i, d, l: r, :] = red_line + elif key in {'low_res_masks', 'frozen_hq_token'}: + if torch.max(kwargs[key]) <= 1: # coconut ori alpha + kwargs[key] = kwargs[key].permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0 + else: + kwargs[key] = F.interpolate(kwargs[key], size=(kwargs['images'].shape[-3], kwargs['images'].shape[-2]), mode='bilinear', align_corners=False) + kwargs[key] = (kwargs[key] > self.sam_predictor.model.mask_threshold).float().permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0 + else: + kwargs[key] = kwargs[key].permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0 + + kwargs[key] = np.uint8(kwargs[key].detach().cpu().numpy()) + + for i in range(kwargs['images'].shape[0]): + save_results = [] + for key in kwargs.keys(): + if key != 'bbox': + save_results.append(kwargs[key][i]) + batch_save_results.append(np.concatenate(save_results, axis=1)) + + Image.fromarray(np.concatenate(batch_save_results, axis=0)).save(save_path) + + def preprocess_inputs(self, batched_inputs): + """ + Normalize, pad and batch the input images. + """ + output = dict() + + if "alpha" in batched_inputs: + alpha = batched_inputs["alpha"].to(self.device) + else: + alpha = None + + bbox = batched_inputs["bbox"].to(self.device) + + if self.training and self.coconut_self_training and sum([i == 'COCONut' for i in batched_inputs['dataset_name']]) >= 1: + output['coconut_ori_img'] = [] + output['coconut_trimap'] = [] + output['coconut_bbox'] = [] + output['coconut_idx'] = [] + for i, dataset_name in enumerate(batched_inputs['dataset_name']): + if dataset_name == 'COCONut': + # generate coconut_aug_img + img_np = np.uint8(batched_inputs["image"][i].permute(1, 2, 0).cpu().numpy() * 255.) + strong_aug_img = self.rand_aug(Image.fromarray(img_np), cutout = False) + strong_aug_img_tensor = torch.from_numpy(np.array(strong_aug_img)).to(self.device).permute(2, 0, 1)[None] / 255. + blur_kernel_sigma = 1.0 + random.random() # random from 1.0 ~ 2.0 + blur_filter = kf.GaussianBlur2d((101, 101), (blur_kernel_sigma, blur_kernel_sigma)) + blur_strong_aug_img_tensor = blur_filter(strong_aug_img_tensor)[0] + + output['coconut_ori_img'].append(batched_inputs["image"][i]) + batched_inputs["image"][i] = blur_strong_aug_img_tensor + + # generate coconut_trimap + coconut_mask = (alpha[i] != 0).float() + mask_area = torch.sum(coconut_mask) + kernel_size = max(self.matting_decoder.min_kernel_size, int((mask_area ** 0.5) / 7)) # self.matting_decoder.kernel_div + kernel_size = min(kernel_size, self.matting_decoder.gen_trimap.max_kernal - 1) + output['coconut_trimap'].append(self.matting_decoder.gen_trimap(coconut_mask[0], kernel_size=kernel_size)[None]) + + output['coconut_bbox'].append(bbox[i]) + output['coconut_idx'].append(i) + + output['coconut_ori_img'] = torch.stack(output['coconut_ori_img']).to(self.device) + output['coconut_ori_img'] = (output['coconut_ori_img'] - self.pixel_mean) / self.pixel_std + output['coconut_trimap'] = torch.stack(output['coconut_trimap']).to(self.device) + output['coconut_bbox'] = torch.stack(output['coconut_bbox']).to(self.device) + + images = batched_inputs["image"].to(self.device) + images = (images - self.pixel_mean) / self.pixel_std + assert images.shape[-2] == images.shape[-1] == 1024 + + if 'trimap' in batched_inputs.keys(): + trimap = batched_inputs["trimap"].to(self.device) + assert len(torch.unique(trimap)) <= 3 + else: + trimap = None + + output['images'] = images + output['bbox'] = bbox + output['alpha'] = alpha + output['trimap'] = trimap + + if 'hr_images' in batched_inputs.keys(): + hr_images = batched_inputs["hr_images"].to(self.device) + hr_images = (hr_images - self.pixel_mean) / self.pixel_std + _, _, H, W = hr_images.shape + if hr_images.shape[-1] % 16 != 0 or hr_images.shape[-2] % 16 != 0: + new_H = (16 - hr_images.shape[-2] % 16) + H if hr_images.shape[-2] % 16 != 0 else H + new_W = (16 - hr_images.shape[-1] % 16) + W if hr_images.shape[-1] % 16 != 0 else W + new_hr_images = torch.zeros((hr_images.shape[0], hr_images.shape[1], new_H, new_W)).to(self.device) + new_hr_images[:,:,:H,:W] = hr_images[:,:,:,:] + del hr_images + hr_images = new_hr_images + output['hr_images'] = hr_images + output['hr_images_ori_h_w'] = (H, W) + + if 'dataset_name' in batched_inputs.keys(): + output['dataset_name'] = batched_inputs["dataset_name"] + + if self.backbone_condition: + if self.w_only_bbox_cond: + output['condition'] = output['bbox'][:, 0, :] + else: + multi_fg_float = batched_inputs["multi_fg"].to(bbox.device).float()[:, None] * 512 + output['condition'] = torch.concat((output['bbox'][:, 0, :], multi_fg_float), dim=-1) + else: + output['condition'] = None + + return output diff --git a/modeling/semantic_enhanced_matting/__init__.py b/modeling/semantic_enhanced_matting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34433aa701e25ef6cb385b2de9ba7a82037822b4 --- /dev/null +++ b/modeling/semantic_enhanced_matting/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .build_sam import ( + build_sam, + build_sam_vit_h, + build_sam_vit_l, + build_sam_vit_b, + sam_model_registry, +) +from .build_sam_baseline import sam_model_registry_baseline +from .predictor import SamPredictor +from .automatic_mask_generator import SamAutomaticMaskGenerator +from .mask_decoder_matting import MaskDecoderMatting \ No newline at end of file diff --git a/modeling/semantic_enhanced_matting/__pycache__/__init__.cpython-38.pyc b/modeling/semantic_enhanced_matting/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a8ced693e3c6b0b6fe71f8d222761b3204bf42d Binary files /dev/null and b/modeling/semantic_enhanced_matting/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/__pycache__/automatic_mask_generator.cpython-38.pyc b/modeling/semantic_enhanced_matting/__pycache__/automatic_mask_generator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..397755a740bdc7ef3a58cf23d9a276a633e6d78e Binary files /dev/null and b/modeling/semantic_enhanced_matting/__pycache__/automatic_mask_generator.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/__pycache__/build_sam.cpython-38.pyc b/modeling/semantic_enhanced_matting/__pycache__/build_sam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6859e02f1a93531858c6d7deb5fc6bf9f510bbb7 Binary files /dev/null and b/modeling/semantic_enhanced_matting/__pycache__/build_sam.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/__pycache__/build_sam_baseline.cpython-38.pyc b/modeling/semantic_enhanced_matting/__pycache__/build_sam_baseline.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f254182bc71c3cf0d1751c5863b0531e79ab5f9b Binary files /dev/null and b/modeling/semantic_enhanced_matting/__pycache__/build_sam_baseline.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/__pycache__/condition_conv.cpython-38.pyc b/modeling/semantic_enhanced_matting/__pycache__/condition_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761af4e81af4ad07402a1a3d701c113aa8fdbd41 Binary files /dev/null and b/modeling/semantic_enhanced_matting/__pycache__/condition_conv.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/__pycache__/feature_fusion.cpython-38.pyc b/modeling/semantic_enhanced_matting/__pycache__/feature_fusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..855fbbee0eb2a895a07ca1b5e676945349017c19 Binary files /dev/null and b/modeling/semantic_enhanced_matting/__pycache__/feature_fusion.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/__pycache__/mask_decoder_matting.cpython-38.pyc b/modeling/semantic_enhanced_matting/__pycache__/mask_decoder_matting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d07377cf37c8e22364401350ac5a8e47b65ffce4 Binary files /dev/null and b/modeling/semantic_enhanced_matting/__pycache__/mask_decoder_matting.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/__pycache__/predictor.cpython-38.pyc b/modeling/semantic_enhanced_matting/__pycache__/predictor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7935cc5accf94b9cbd8604132d43ff8534f80de Binary files /dev/null and b/modeling/semantic_enhanced_matting/__pycache__/predictor.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/automatic_mask_generator.py b/modeling/semantic_enhanced_matting/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..427ebebd831f848dfff219f695c45302228e449a --- /dev/null +++ b/modeling/semantic_enhanced_matting/automatic_mask_generator.py @@ -0,0 +1,374 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image, multimask_output) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + multimask_output: bool = True, + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + multimask_output: bool = True, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/modeling/semantic_enhanced_matting/build_sam.py b/modeling/semantic_enhanced_matting/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..242944a15470ae6975c0e893e7e498461844db58 --- /dev/null +++ b/modeling/semantic_enhanced_matting/build_sam.py @@ -0,0 +1,234 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer, TinyViT +from .modeling.mask_decoder_hq_matting import MaskDecoderHQMatting + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None, matting_token=0, wo_hq=False, frozen_decoder=False, mask_matting_res_add=True): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + matting_token=matting_token, + wo_hq=wo_hq, + frozen_decoder=frozen_decoder, + mask_matting_res_add=mask_matting_res_add + ) + + +def build_sam_vit_b(checkpoint=None, matting_token=False, wo_hq=False, frozen_decoder=False): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + matting_token=matting_token, + wo_hq=wo_hq, + frozen_decoder=frozen_decoder + ) + + +def build_sam_vit_t(checkpoint=None): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8 + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoderHQ( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + vit_dim=160, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + + mobile_sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + device = "cuda" if torch.cuda.is_available() else "cpu" + state_dict = torch.load(f, map_location=device) + info = mobile_sam.load_state_dict(state_dict, strict=False) + print(info) + for n, p in mobile_sam.named_parameters(): + if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n: + p.requires_grad = False + return mobile_sam + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, + "vit_tiny": build_sam_vit_t +} + +def sam_model_registry_def(model_type, checkpoint, matting_token = 0, wo_hq = False, frozen_decoder = False, mask_matting_res_add=True): + assert model_type in {"default", "vit_h", "vit_l", "vit_b", "vit_tiny"} + return sam_model_registry[model_type](checkpoint=checkpoint, matting_token=matting_token, wo_hq=wo_hq, frozen_decoder=frozen_decoder, mask_matting_res_add=mask_matting_res_add) + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + matting_token=0, + wo_hq=False, + frozen_decoder=False, + mask_matting_res_add=True +): + # no_res_add only work when wo_hq and have mat ting token + if not mask_matting_res_add: + assert matting_token > 0 + + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + + if matting_token > 0: + mask_decoder = MaskDecoderHQMatting( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + vit_dim=encoder_embed_dim, + wo_hq=wo_hq, + matting_token_num=matting_token, + mask_matting_res_add=mask_matting_res_add + ) + else: + mask_decoder = MaskDecoderHQ( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + vit_dim=encoder_embed_dim, + wo_hq=wo_hq + ) + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=mask_decoder, + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + device = "cuda" if torch.cuda.is_available() else "cpu" + state_dict = torch.load(f, map_location=device) + info = sam.load_state_dict(state_dict, strict=False) + print(info) + + if frozen_decoder and checkpoint is not None: + sam.frozen_mask_decoder = MaskDecoderHQ( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + vit_dim=encoder_embed_dim, + wo_hq=wo_hq + ) + sam.frozen_mask_decoder.eval() + info = sam.frozen_mask_decoder.load_state_dict({key.split('mask_decoder.')[1]: val for key, val in state_dict.items() if 'mask_decoder.' in key}, strict=False) + print('load frozen_mask_decoder', info) + # for n, p in sam.frozen_mask_decoder.named_parameters(): + # p = state_dict['mask_decoder.' + n] + + for n, p in sam.named_parameters(): + # if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n: + # p.requires_grad = False + if 'matting' not in n: + p.requires_grad = False + # p.requires_grad = False + + return sam diff --git a/modeling/semantic_enhanced_matting/build_sam_baseline.py b/modeling/semantic_enhanced_matting/build_sam_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d34d702821ef49dd451daa20bb3897e76357f2 --- /dev/null +++ b/modeling/semantic_enhanced_matting/build_sam_baseline.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def build_sam_vit_t(checkpoint=None): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8 + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + + mobile_sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + mobile_sam.load_state_dict(state_dict) + return mobile_sam + +sam_model_registry_baseline = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, + "vit_tiny": build_sam_vit_t +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam \ No newline at end of file diff --git a/modeling/semantic_enhanced_matting/condition_conv.py b/modeling/semantic_enhanced_matting/condition_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..c03b8e48ffa98b4cea4b57f6a95492c9cd6b9c33 --- /dev/null +++ b/modeling/semantic_enhanced_matting/condition_conv.py @@ -0,0 +1,504 @@ +import math +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from detectron2.layers import Conv2d +import fvcore.nn.weight_init as weight_init +from typing import Any, Optional, Tuple, Type + +from modeling.semantic_enhanced_matting.modeling.image_encoder import Attention +from modeling.semantic_enhanced_matting.modeling.transformer import Attention as DownAttention +from modeling.semantic_enhanced_matting.feature_fusion import PositionEmbeddingRandom as ImagePositionEmbedding +from modeling.semantic_enhanced_matting.modeling.common import MLPBlock + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class ConditionConv(nn.Module): + """ + The standard bottleneck residual block without the last activation layer. + It contains 3 conv layers with kernels 1x1, 3x3, 1x1. + """ + + def __init__( + self, + in_channels, + out_channels, + bottleneck_channels, + norm=LayerNorm2d, + act_layer=nn.GELU, + conv_kernels=3, + conv_paddings=1, + condtition_channels = 1024 + ): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + act_layer (callable): activation for all conv layers. + """ + super().__init__() + + self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False) + self.norm1 = norm(bottleneck_channels) + self.act1 = act_layer() + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + conv_kernels, + padding=conv_paddings, + bias=False, + ) + self.norm2 = norm(bottleneck_channels) + self.act2 = act_layer() + + self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False) + self.norm3 = norm(out_channels) + + self.init_weight() + + self.condition_embedding = nn.Sequential( + act_layer(), + nn.Linear(condtition_channels, bottleneck_channels, bias=True) + ) + + def init_weight(self): + for layer in [self.conv1, self.conv2, self.conv3]: + weight_init.c2_msra_fill(layer) + for layer in [self.norm1, self.norm2]: + layer.weight.data.fill_(1.0) + layer.bias.data.zero_() + # zero init last norm layer. + self.norm3.weight.data.zero_() + self.norm3.bias.data.zero_() + + # def embed_bbox_and_instance(self, bbox, instance): + # assert isinstance(instance, bool) + + def forward(self, x, condition): + # [B, 64, 64, 1024] + out = x.permute(0, 3, 1, 2) + + out = self.act1(self.norm1(self.conv1(out))) + out = self.conv2(out) + self.condition_embedding(condition)[:, :, None, None] + out = self.act2(self.norm2(out)) + out = self.norm3(self.conv3(out)) + + out = x + out.permute(0, 2, 3, 1) + return out + + +class ConditionAdd(nn.Module): + def __init__( + self, + act_layer=nn.GELU, + condtition_channels = 1024 + ): + super().__init__() + + self.condition_embedding = nn.Sequential( + act_layer(), + nn.Linear(condtition_channels, condtition_channels, bias=True) + ) + + def forward(self, x, condition): + # [B, 64, 64, 1024] + condition = self.condition_embedding(condition)[:, None, None, :] + return x + condition + +class ConditionEmbedding(nn.Module): + def __init__( + self, + condition_num = 5, + pos_embedding_dim = 128, + embedding_scale = 1.0, + embedding_max_period = 10000, + embedding_flip_sin_to_cos = True, + embedding_downscale_freq_shift = 1.0, + time_embed_dim = 1024, + split_embed = False + ): + super().__init__() + self.condition_num = condition_num + self.pos_embedding_dim = pos_embedding_dim + self.embedding_scale = embedding_scale + self.embedding_max_period = embedding_max_period + self.embedding_flip_sin_to_cos = embedding_flip_sin_to_cos + self.embedding_downscale_freq_shift = embedding_downscale_freq_shift + self.split_embed = split_embed + + if self.split_embed: + self.linear_1 = nn.Linear(pos_embedding_dim, time_embed_dim, True) + else: + self.linear_1 = nn.Linear(condition_num * pos_embedding_dim, time_embed_dim, True) + self.act = nn.GELU() + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, True) + + def proj_embedding(self, condition): + sample = self.linear_1(condition) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + def position_embedding(self, condition): + # [B, 5] --> [B, 5, 128] --> [B, 5 * 128] + + assert condition.shape[-1] == self.condition_num + + half_dim = self.pos_embedding_dim // 2 + exponent = -math.log(self.embedding_max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=condition.device + ) + exponent = exponent / (half_dim - self.embedding_downscale_freq_shift) + + emb = torch.exp(exponent) + emb = condition[:, :, None].float() * emb[None, None, :] # [B, 5, 1] * [1, 1, 64] --> [B, 5, 64] + + # scale embeddings + emb = self.embedding_scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # [B, 5, 64] --> [B, 5, 128] + + # flip sine and cosine embeddings + if self.embedding_flip_sin_to_cos: + emb = torch.cat([emb[:, :, half_dim:], emb[:, :, :half_dim]], dim=-1) + + # zero pad + # if self.pos_embedding_dim % 2 == 1: + # emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + if self.split_embed: + emb = emb.reshape(-1, emb.shape[-1]) + else: + emb = emb.reshape(emb.shape[0], -1) + + return emb + + def forward(self, condition): + condition = self.position_embedding(condition) + condition = self.proj_embedding(condition) + return condition.float() + + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + + self.positional_encoding_gaussian_matrix = nn.Parameter(scale * torch.randn((2, num_pos_feats // 2))) + # self.register_buffer( + # "positional_encoding_gaussian_matrix", + # scale * torch.randn((2, num_pos_feats)), + # ) + point_embeddings = [nn.Embedding(1, num_pos_feats) for i in range(2)] + self.point_embeddings = nn.ModuleList(point_embeddings) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + coords = self._pe_encoding(coords.to(torch.float)) # B x N x C + + coords[:, 0, :] += self.point_embeddings[0].weight + coords[:, 1, :] += self.point_embeddings[1].weight + + return coords + + +class CrossSelfAttn(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, embedding_dim=1024, num_heads=4, downsample_rate=4) -> None: + super().__init__() + + self.cross_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate) + self.norm1 = nn.LayerNorm(embedding_dim) + self.mlp = MLPBlock(embedding_dim, mlp_dim=512) + self.norm2 = nn.LayerNorm(embedding_dim) + self.self_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate) + self.norm3 = nn.LayerNorm(embedding_dim) + + def forward(self, block_feat, bbox_token, feat_pe, bbox_pe): + B, H, W, C = block_feat.shape + block_feat = block_feat.reshape(B, H * W, C) + + block_feat = block_feat + self.cross_attn(q=block_feat + feat_pe, k=bbox_token + bbox_pe, v=bbox_token) + block_feat = self.norm1(block_feat) + + block_feat = block_feat + self.mlp(block_feat) + block_feat = self.norm2(block_feat) + + concat_token = torch.concat((block_feat + feat_pe, bbox_token + bbox_pe), dim=1) + block_feat = block_feat + self.self_attn(q=concat_token, k=concat_token, v=concat_token)[:, :-bbox_token.shape[1]] + block_feat = self.norm3(block_feat) + output = block_feat.reshape(B, H, W, C) + + return output + + +class BBoxEmbedInteract(nn.Module): + def __init__( + self, + embed_type = 'fourier', + interact_type = 'attn', + layer_num = 3 + ): + super().__init__() + assert embed_type in {'fourier', 'position', 'conv'} + assert interact_type in {'add', 'attn', 'cross-self-attn'} + self.embed_type = embed_type + self.interact_type = interact_type + self.layer_num = layer_num + + if self.embed_type == 'fourier' and self.interact_type == 'add': + self.embed_layer = ConditionEmbedding(condition_num = 4, pos_embedding_dim = 256) + elif self.embed_type == 'fourier': + self.embed_layer = ConditionEmbedding(condition_num = 4, pos_embedding_dim = 256, split_embed = True) + elif self.embed_type == 'conv': + mask_in_chans = 16 + activation = nn.GELU + self.embed_layer = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, 1024, kernel_size=1), + ) + else: + if self.interact_type == 'add': + self.embed_layer = PositionEmbeddingRandom(num_pos_feats = 512) + else: + self.embed_layer = PositionEmbeddingRandom(num_pos_feats = 1024) + + self.interact_layer = nn.ModuleList() + for _ in range(self.layer_num): + if self.interact_type == 'attn': + self.interact_layer.append(Attention(dim = 1024)) + elif self.interact_type == 'add' and self.embed_type != 'conv': + self.interact_layer.append(nn.Sequential( + nn.GELU(), + nn.Linear(1024, 1024, bias=True) + )) + elif self.interact_type == 'cross-self-attn': + self.interact_layer.append(CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4)) + + self.position_layer = ImagePositionEmbedding(num_pos_feats=1024 // 2) + + def forward(self, block_feat, bbox, layer_index): + # input: [B, 1, 4], [B, 64, 64, 1024] + if layer_index == self.layer_num: + return block_feat + interact_layer = self.interact_layer[layer_index] + + bbox = bbox + 0.5 # Shift to center of pixel + if self.embed_type == 'fourier' and self.interact_type == 'add': + embedding = self.embed_layer(bbox[:, 0]) # [B, 1, 4] --> reshape [B, 4] --> [B, 1024 * 1] --> reshape [B, 1, 1024] + embedding = embedding.reshape(embedding.shape[0], 1, -1) + elif self.embed_type == 'fourier': + embedding = self.embed_layer(bbox[:, 0]) # [B, 1, 4] --> reshape [B, 4] --> [B, 1024 * 4] --> reshape [B, 4, 1024] + embedding = embedding.reshape(-1, 4, embedding.shape[-1]) + elif self.embed_type == 'conv': + # concat mask and img as condition + bbox_mask = torch.zeros(size=(block_feat.shape[0], 1, 256, 256), device=block_feat.device, dtype=block_feat.dtype) # [B, 1, 512, 512] + for i in range(bbox.shape[0]): + l, u, r, d = bbox[i, 0, :] / 4 + bbox_mask[i, :, int(u + 0.5): int(d + 0.5), int(l + 0.5): int(r + 0.5)] = 1.0 # int(x + 0.5) = round(x) + embedding = self.embed_layer(bbox_mask) # [B, 1024, 64, 64] + elif self.embed_type == 'position': + embedding = self.embed_layer(bbox.reshape(-1, 2, 2), (1024, 1024)) # [B, 1, 4] --> reshape [B, 2, 2] --> [B, 2, 1024/512] + if self.interact_type == 'add': + embedding = embedding.reshape(embedding.shape[0], 1, -1) + + # add position embedding to block_feat + pe = self.position_layer(size=(64, 64)).reshape(1, 64, 64, 1024) + block_feat = block_feat + pe + + if self.interact_type == 'attn': + add_token_num = embedding.shape[1] + B, H, W, C = block_feat.shape + block_feat = block_feat.reshape(B, H * W, C) + concat_token = torch.concat((block_feat, embedding), dim=1) # [B, 64 * 64 + 2, 1024] + output_token = interact_layer.forward_token(concat_token)[:, :-add_token_num] + output = output_token.reshape(B, H, W, C) + elif self.embed_type == 'conv': + output = block_feat + embedding.permute(0, 2, 3, 1) + elif self.interact_type == 'add': + output = interact_layer(embedding[:, None]) + block_feat + elif self.interact_type == 'cross-self-attn': + output = interact_layer(block_feat, embedding) + + return output + + +# reuse the position_point_embedding in prompt_encoder +class BBoxInteract(nn.Module): + def __init__( + self, + position_point_embedding, + point_weight, + layer_num = 3, + ): + super().__init__() + + self.position_point_embedding = position_point_embedding + self.point_weight = point_weight + for _, p in self.named_parameters(): + p.requires_grad = False + + self.layer_num = layer_num + self.input_image_size = (1024, 1024) + + self.interact_layer = nn.ModuleList() + for _ in range(self.layer_num): + self.interact_layer.append(CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4)) + + @torch.no_grad() + def get_bbox_token(self, boxes): + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.position_point_embedding.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_weight[2].weight + corner_embedding[:, 1, :] += self.point_weight[3].weight + corner_embedding = F.interpolate(corner_embedding[..., None], size=(1024, 1), mode='bilinear', align_corners=False)[..., 0] + return corner_embedding # [B, 2, 1024] + + @torch.no_grad() + def get_position_embedding(self, size=(64, 64)): + pe = self.position_point_embedding(size=size) + pe = F.interpolate(pe.permute(1, 2, 0)[..., None], size=(1024, 1), mode='bilinear', align_corners=False)[..., 0][None] + pe = pe.reshape(1, -1, 1024) + return pe # [1, 64 * 64, 1024] + + def forward(self, block_feat, bbox, layer_index): + # input: [B, 1, 4], [B, 64, 64, 1024] + if layer_index == self.layer_num: + return block_feat + interact_layer = self.interact_layer[layer_index] + + pe = self.get_position_embedding() + bbox_token = self.get_bbox_token(bbox) + + output = interact_layer(block_feat, bbox_token, feat_pe=pe, bbox_pe=bbox_token) + + return output + +class InOutBBoxCrossSelfAttn(nn.Module): + + def __init__(self, embedding_dim=1024, num_heads=4, downsample_rate=4) -> None: + super().__init__() + + self.self_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate) + self.norm1 = nn.LayerNorm(embedding_dim) + self.mlp = MLPBlock(embedding_dim, mlp_dim=embedding_dim // 2) + self.norm2 = nn.LayerNorm(embedding_dim) + self.cross_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate) + self.norm3 = nn.LayerNorm(embedding_dim) + + def forward(self, in_box_token, out_box_token): + + # self-attn + short_cut = in_box_token + in_box_token = self.norm1(in_box_token) + in_box_token = self.self_attn(q=in_box_token, k=in_box_token, v=in_box_token) + in_box_token = short_cut + in_box_token + + # mlp + in_box_token = in_box_token + self.mlp(self.norm2(in_box_token)) + + # cross-attn + short_cut = in_box_token + in_box_token = self.norm3(in_box_token) + in_box_token = self.cross_attn(q=in_box_token, k=out_box_token, v=out_box_token) + in_box_token = short_cut + in_box_token + + return in_box_token + + +class BBoxInteractInOut(nn.Module): + def __init__( + self, + num_heads = 4, + downsample_rate = 4, + layer_num = 3, + ): + super().__init__() + + self.layer_num = layer_num + self.input_image_size = (1024, 1024) + + self.interact_layer = nn.ModuleList() + for _ in range(self.layer_num): + self.interact_layer.append(InOutBBoxCrossSelfAttn(embedding_dim=1024, num_heads=num_heads, downsample_rate=downsample_rate)) + + def forward(self, block_feat, bbox, layer_index): + + # input: [B, 1, 4], [B, 64, 64, 1024] + if layer_index == self.layer_num: + return block_feat + interact_layer = self.interact_layer[layer_index] + + # split_in_out_bbox_token + bbox = torch.round(bbox / self.input_image_size[0] * (block_feat.shape[1] - 1)).int() + for i in range(block_feat.shape[0]): + in_bbox_mask = torch.zeros((block_feat.shape[1], block_feat.shape[2]), dtype=bool, device=bbox.device) + in_bbox_mask[bbox[i, 0, 1]: bbox[i, 0, 3], bbox[i, 0, 0]: bbox[i, 0, 2]] = True + in_bbox_token = block_feat[i: i + 1, in_bbox_mask, :] + out_bbox_token = block_feat[i: i + 1, ~in_bbox_mask, :] + block_feat[i, in_bbox_mask, :] = interact_layer(in_bbox_token, out_bbox_token) + + return block_feat + + +if __name__ == '__main__': + # emded = ConditionEmbedding() + # input = torch.tensor([[100, 200, 300, 400, 512], [100, 200, 300, 400, 1024]]) + # print(input.shape) + # output = emded(input) # [B, 5] --> [B, 5 * 128] --> [B, 1024] + + emded = BBoxEmbedInteract( + embed_type = 'position', + interact_type = 'cross-self-attn' + ) + input = torch.tensor([[[100, 200, 300, 400]], [[100, 200, 300, 400]]]) # [B, 1, 4] + print(input.shape) + output = emded(torch.randn((2, 64, 64, 1024)), input) # [B, 5] --> [B, 5 * 128] --> [B, 1024] \ No newline at end of file diff --git a/modeling/semantic_enhanced_matting/feature_fusion.py b/modeling/semantic_enhanced_matting/feature_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..27970af201f9bb3769027f9674e6be955a333080 --- /dev/null +++ b/modeling/semantic_enhanced_matting/feature_fusion.py @@ -0,0 +1,283 @@ +import torch +import torch.nn as nn +from typing import Type, Optional, Tuple +import numpy as np + +from .modeling.transformer import Attention +from .modeling.common import MLPBlock +# from modeling.transformer import Attention +# from modeling.common import MLPBlock + + +class MutualCrossAttention(nn.Module): + def __init__( + self, + embedding_dim: int = 1024, + num_heads: int = 8, + mlp_dim: int = 1024, + activation: Type[nn.Module] = nn.GELU, + attention_downsample_rate: int = 4, + ) -> None: + super().__init__() + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.norm3 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + def forward(self, queries, keys, query_pe=None, key_pe=None): + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe if query_pe is not None else queries + k = keys + key_pe if key_pe is not None else keys + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm1(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm2(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe if query_pe is not None else queries + k = keys + key_pe if key_pe is not None else keys + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm3(keys) + + return queries, keys + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + # return pe.permute(2, 0, 1) # C x H x W + return pe.reshape(h * w, -1)[None] # 1 x (H x W) x C + + +class FeatureFusion(nn.Module): + def __init__( + self, + in_channels=1024, + input_compression_ratio=1, + attn_compression_ratio=4, + features_num=4, + w_pe=True, + ): + super().__init__() + + self.input_compression_ratio = input_compression_ratio + if self.input_compression_ratio != 1: + self.mlp_in = nn.ModuleList([nn.Sequential( + nn.Linear(in_channels, in_channels // input_compression_ratio), + # activation(), + # nn.Linear(embedding_dim // compression_ratio, embedding_dim // compression_ratio) + ) for _ in range(features_num)]) + + self.mlp_out = nn.ModuleList([nn.Sequential( + nn.Linear(in_channels // input_compression_ratio, in_channels), + # activation(), + # nn.Linear(embedding_dim, embedding_dim) + ) for _ in range(features_num)]) + + in_channels = in_channels // input_compression_ratio + self.mutual_cross_attn = nn.ModuleList([ + MutualCrossAttention(embedding_dim=in_channels, mlp_dim=in_channels // attn_compression_ratio, attention_downsample_rate=attn_compression_ratio) for _ in range(features_num - 1) + ]) + self.w_pe = w_pe + if self.w_pe: + # no grad + self.get_pe = PositionEmbeddingRandom(in_channels // 2) + with torch.no_grad(): + self.pe = self.get_pe(size=(64, 64)) + + def forward(self, features): + # [B, 64, 64, 1024] x 4 + + b, h, w, _ = features[0].shape + for i in range(len(features)): + features[i] = features[i].reshape(b, h * w, -1) + if self.input_compression_ratio != 1: + features[i] = self.mlp_in[i](features[i]) + + for i in range(len(features) - 1): + features[i], features[i + 1] = self.mutual_cross_attn[i](features[i], features[i + 1], self.pe, self.pe) + + for i in range(len(features)): + features[i] = features[i].reshape(b, h, w, -1) + if self.input_compression_ratio != 1: + features[i] = self.mlp_out[i](features[i]) + + return features + + +if __name__ == '__main__': + + import typing + from collections import defaultdict + import tabulate + from torch import nn + + + def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]: + """ + Count parameters of a model and its submodules. + + Args: + model: a torch module + + Returns: + dict (str-> int): the key is either a parameter name or a module name. + The value is the number of elements in the parameter, or in all + parameters of the module. The key "" corresponds to the total + number of parameters of the model. + """ + r = defaultdict(int) + for name, prm in model.named_parameters(): + if trainable_only: + if not prm.requires_grad: + continue + size = prm.numel() + name = name.split(".") + for k in range(0, len(name) + 1): + prefix = ".".join(name[:k]) + r[prefix] += size + return r + + + def parameter_count_table( + model: nn.Module, max_depth: int = 3, trainable_only: bool = False + ) -> str: + """ + Format the parameter count of the model (and its submodules or parameters) + in a nice table. It looks like this: + + :: + + | name | #elements or shape | + |:--------------------------------|:---------------------| + | model | 37.9M | + | backbone | 31.5M | + | backbone.fpn_lateral3 | 0.1M | + | backbone.fpn_lateral3.weight | (256, 512, 1, 1) | + | backbone.fpn_lateral3.bias | (256,) | + | backbone.fpn_output3 | 0.6M | + | backbone.fpn_output3.weight | (256, 256, 3, 3) | + | backbone.fpn_output3.bias | (256,) | + | backbone.fpn_lateral4 | 0.3M | + | backbone.fpn_lateral4.weight | (256, 1024, 1, 1) | + | backbone.fpn_lateral4.bias | (256,) | + | backbone.fpn_output4 | 0.6M | + | backbone.fpn_output4.weight | (256, 256, 3, 3) | + | backbone.fpn_output4.bias | (256,) | + | backbone.fpn_lateral5 | 0.5M | + | backbone.fpn_lateral5.weight | (256, 2048, 1, 1) | + | backbone.fpn_lateral5.bias | (256,) | + | backbone.fpn_output5 | 0.6M | + | backbone.fpn_output5.weight | (256, 256, 3, 3) | + | backbone.fpn_output5.bias | (256,) | + | backbone.top_block | 5.3M | + | backbone.top_block.p6 | 4.7M | + | backbone.top_block.p7 | 0.6M | + | backbone.bottom_up | 23.5M | + | backbone.bottom_up.stem | 9.4K | + | backbone.bottom_up.res2 | 0.2M | + | backbone.bottom_up.res3 | 1.2M | + | backbone.bottom_up.res4 | 7.1M | + | backbone.bottom_up.res5 | 14.9M | + | ...... | ..... | + + Args: + model: a torch module + max_depth (int): maximum depth to recursively print submodules or + parameters + + Returns: + str: the table to be printed + """ + count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only) + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. + param_shape: typing.Dict[str, typing.Tuple] = { + k: tuple(v.shape) for k, v in model.named_parameters() + } + + # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. + table: typing.List[typing.Tuple] = [] + + def format_size(x: int) -> str: + if x > 1e8: + return "{:.1f}G".format(x / 1e9) + if x > 1e5: + return "{:.1f}M".format(x / 1e6) + if x > 1e2: + return "{:.1f}K".format(x / 1e3) + return str(x) + + def fill(lvl: int, prefix: str) -> None: + if lvl >= max_depth: + return + for name, v in count.items(): + if name.count(".") == lvl and name.startswith(prefix): + indent = " " * (lvl + 1) + if name in param_shape: + table.append((indent + name, indent + str(param_shape[name]))) + else: + table.append((indent + name, indent + format_size(v))) + fill(lvl + 1, name + ".") + + table.append(("model", format_size(count.pop("")))) + fill(0, "") + + old_ws = tabulate.PRESERVE_WHITESPACE + tabulate.PRESERVE_WHITESPACE = True + tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe") + tabulate.PRESERVE_WHITESPACE = old_ws + return tab + + feature_fusion = FeatureFusion(in_channels=1024, attn_compression_ratio=8) + print("All parameters: \n" + parameter_count_table(feature_fusion, max_depth=8)) + features = [torch.randn(2, 64, 64, 1024) for _ in range(4)] + out = feature_fusion(features) + for i in out: + print(i.shape) + print('done') diff --git a/modeling/semantic_enhanced_matting/mask_decoder_matting.py b/modeling/semantic_enhanced_matting/mask_decoder_matting.py new file mode 100644 index 0000000000000000000000000000000000000000..1378f5445f2e25aab117042a5337ca41af4b0cdd --- /dev/null +++ b/modeling/semantic_enhanced_matting/mask_decoder_matting.py @@ -0,0 +1,356 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple +import numpy as np +import cv2 +from detectron2.layers.batch_norm import NaiveSyncBatchNorm + +from modeling.semantic_enhanced_matting.modeling import TwoWayTransformer, MaskDecoder +from modeling.decoder.detail_capture import Detail_Capture +from modeling.decoder.unet_detail_capture import DetailUNet +# from nnMorpho.binary_operators import erosion + + +# class GenTrimapTorch(object): +# def __init__(self, max_kernal=200): +# self.max_kernal = max_kernal +# self.erosion_kernels = [None] + [torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))).float().cuda() for size in range(1, self.max_kernal)] + +# def __call__(self, mask, kernel_size): + +# fg_width = kernel_size +# bg_width = kernel_size + +# fg_mask = mask +# bg_mask = 1 - mask + +# fg_mask = erosion(fg_mask, self.erosion_kernels[fg_width], border='a') +# bg_mask = erosion(bg_mask, self.erosion_kernels[bg_width], border='a') + +# trimap = torch.ones_like(mask) * 0.5 +# trimap[fg_mask == 1] = 1.0 +# trimap[bg_mask == 1] = 0.0 + +# return trimap + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + +class MaskDecoderMatting(MaskDecoder): + def __init__( + self, + model_type, + checkpoint_path, + detail_capture, + mask_token_only, + norm_type = 'LN', + norm_mask_logits = False, + with_trimap = False, + min_kernel_size = 20, + kernel_div = 10, + concat_gen_trimap = False, + ): + super().__init__( + transformer_dim=256, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=256, + mlp_dim=2048, + num_heads=8, + ), + num_multimask_outputs=3, + activation=nn.GELU, + iou_head_depth=3, + iou_head_hidden_dim=256, + ) + assert model_type in ["vit_b","vit_l","vit_h"] + + assert norm_type in {'BN', 'LN', 'SyncBN'} + if norm_type == 'BN': + self.norm = torch.nn.BatchNorm2d + elif norm_type == 'SyncBN': + self.norm = NaiveSyncBatchNorm + else: + self.norm = LayerNorm2d + + # checkpoint_dict = {"vit_b":"pretrained_checkpoint/sam_vit_b_maskdecoder.pth", + # "vit_l":"pretrained_checkpoint/sam_vit_l_maskdecoder.pth", + # 'vit_h':"pretrained_checkpoint/sam_vit_h_maskdecoder.pth"} + # checkpoint_path = checkpoint_dict[model_type] + + self.load_state_dict(torch.load(checkpoint_path)) + print("Matting Decoder init from SAM MaskDecoder") + + self.frozen_params_str = set() + for n, p in self.named_parameters(): + p.requires_grad = False + self.frozen_params_str.add(n) + + self.detail_capture = detail_capture + self.mask_token_only = mask_token_only + self.norm_mask_logits = norm_mask_logits + + transformer_dim = 256 + vit_dim_dict = {"vit_b":768,"vit_l":1024,"vit_h":1280} + vit_dim = vit_dim_dict[model_type] + + self.hf_token = nn.Embedding(1, transformer_dim) + self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + self.num_mask_tokens = self.num_mask_tokens + 1 + self.concat_gen_trimap = concat_gen_trimap + + self.compress_vit_feat = nn.Sequential( + nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2), + self.norm(transformer_dim), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2) + ) + + self.embedding_encoder = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + self.norm(transformer_dim // 4), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + ) + + self.embedding_maskfeature = nn.Sequential( + nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1), + self.norm(transformer_dim // 4), + nn.GELU(), + nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1) + ) + + if isinstance(self.detail_capture, Detail_Capture): + self.glue_layer_0 = nn.Conv2d(self.detail_capture.fus_channs[2], transformer_dim // 8, 3, 1, 1) + else: + assert isinstance(self.detail_capture, DetailUNet) + + self.trainable_params_str = set() + for n, p in self.named_parameters(): + if p.requires_grad: + self.trainable_params_str.add(n) + + self.with_trimap = with_trimap + self.min_kernel_size = min_kernel_size + self.kernel_div = kernel_div + if self.with_trimap and not self.concat_gen_trimap: + # self.gen_trimap = GenTrimapTorch() + raise ValueError('Discard GenTrimapTorch') + + # self.trainable_params_str = {'detail_capture', 'hf_token', 'hf_mlp', 'compress_vit_feat', 'embedding_encoder', 'embedding_maskfeature', 'glue_layer_0'} + # for n, p in self.named_parameters(): + # if p.requires_grad: + # assert n.split('.')[0] in self.trainable_params_str + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + # hq_token_only: bool, + interm_embeddings: torch.Tensor, + hq_features: torch.Tensor, + images: torch.Tensor, + hr_images_ori_h_w = None, + return_alpha_logits = False, + pred_trimap=None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the ViT image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted hq masks + """ + vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT [B, 64, 64, 768] + + # upsample image_embeddings x4.0 with detail_capture & embedding_encoder & compress_vit_feat + # regard hq_features as condition + if isinstance(self.norm_mask_logits, float): + norm_hq_features = hq_features / self.norm_mask_logits + elif self.norm_mask_logits: + norm_hq_features = hq_features / torch.std(hq_features, dim=(1, 2, 3), keepdim=True) + else: + norm_hq_features = hq_features + + if hr_images_ori_h_w is not None: + assert not isinstance(self.detail_capture, Detail_Capture) and hq_features.shape[-2] == hq_features.shape[-1] == 1024 + lr_images_before_pad_h_w = (1024 / max(hr_images_ori_h_w) * hr_images_ori_h_w[0], 1024 / max(hr_images_ori_h_w) * hr_images_ori_h_w[1]) + lr_images_before_pad_h_w = (int(lr_images_before_pad_h_w[0] + 0.5), int(lr_images_before_pad_h_w[1] + 0.5)) + norm_hq_features = F.interpolate( + norm_hq_features[:, :, :lr_images_before_pad_h_w[0], :lr_images_before_pad_h_w[1]], + size = (images.shape[-2], images.shape[-1]), + mode = 'bilinear', + align_corners = False + ) + + if self.concat_gen_trimap: + pred_trimap = F.interpolate(pred_trimap, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False) + pred_trimap = torch.argmax(pred_trimap, dim=1, keepdim=True).float() / 2.0 + norm_hq_features = torch.concat((norm_hq_features, pred_trimap), dim=1) + elif self.with_trimap: + mask = (norm_hq_features > 0).float() + for i_batch in range(image_embeddings.shape[0]): + mask_area = torch.sum(mask[i_batch]) + kernel_size = max(self.min_kernel_size, int((mask_area ** 0.5) / self.kernel_div)) + kernel_size = min(kernel_size, self.gen_trimap.max_kernal - 1) + mask[i_batch, 0] = self.gen_trimap(mask[i_batch, 0], kernel_size=kernel_size) + trimaps = mask + norm_hq_features = torch.concat((norm_hq_features, trimaps), dim=1) + + conditional_images = torch.concatenate((images, norm_hq_features), dim=1) # [B, 4, 1024, 1024] + + if isinstance(self.detail_capture, Detail_Capture): + detail_features = self.detail_capture.convstream(conditional_images) # [B, 4, 1024, 1024] --> D0: [B, 4, 1024, 1024], D1: [B, 48, 512, 512], D2: [B, 96, 256, 256], D3: [B, 192, 128, 128] + matting_features = self.detail_capture.fusion_blks[0](image_embeddings, detail_features['D3']) # [B, 256, 64, 64] & [B, 192, 128, 128] --> [B, 256, 128, 128] + matting_features = self.detail_capture.fusion_blks[1](matting_features, detail_features['D2']) # [B, 256, 128, 128] & [B, 96, 256, 256] --> [B, 128, 256, 256] + matting_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) + self.glue_layer_0(matting_features) # [B, 32, 256, 256] + else: + matting_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) + + batch_len = len(image_embeddings) + masks = [] + iou_preds = [] + for i_batch in range(batch_len): + mask, iou_pred = self.predict_masks( + image_embeddings=image_embeddings[i_batch].unsqueeze(0), + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings[i_batch], + dense_prompt_embeddings=dense_prompt_embeddings[i_batch], + matting_feature = matting_features[i_batch].unsqueeze(0) + ) + masks.append(mask) + iou_preds.append(iou_pred) + masks = torch.cat(masks, 0) # [B, 5, 256, 256] + iou_preds = torch.cat(iou_preds, 0) # [4, 4] + + if self.mask_token_only: + masks_matting = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens), :, :] # [B, 1, 256, 256] + else: + masks_matting = masks # [B, 5, 256, 256] + + if hr_images_ori_h_w is not None: + vit_features = F.interpolate( + vit_features[:, :, :math.ceil(lr_images_before_pad_h_w[0] / 16), :math.ceil(lr_images_before_pad_h_w[1] / 16)], + size = (images.shape[-2] // 16, images.shape[-1] // 16), + mode = 'bilinear', + align_corners = False + ) + masks_matting = F.interpolate( + masks_matting[:, :, :math.ceil(lr_images_before_pad_h_w[0] / 4), :math.ceil(lr_images_before_pad_h_w[1] / 4)], + size = (images.shape[-2] // 4, images.shape[-1] // 4), + mode = 'bilinear', + align_corners = False + ) + + if isinstance(self.detail_capture, Detail_Capture): + matting_features = self.detail_capture.fusion_blks[2](masks_matting, detail_features['D1']) + matting_features = self.detail_capture.fusion_blks[3](matting_features, detail_features['D0']) + alpha = torch.sigmoid(self.detail_capture.matting_head(matting_features)) + else: + if return_alpha_logits: + output = self.detail_capture(conditional_images, vit_features, masks_matting, return_alpha_logits = True) + alpha = torch.sigmoid(output[0]), output[1] + else: + alpha = torch.sigmoid(self.detail_capture(conditional_images, vit_features, masks_matting, return_alpha_logits = False)) + + if hr_images_ori_h_w is not None: + alpha = alpha[:, :, :hr_images_ori_h_w[0], :hr_images_ori_h_w[1]] + + return alpha + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + matting_feature: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0) # [6, 256] + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) # [1, 6, 256] + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # [1, 8, 256] + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) # [1, 256, 64, 64] + src = src + dense_prompt_embeddings # [1, 256, 64, 64] + [1, 256, 64, 64] + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) # [1, 256, 64, 64] + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) # [1, 8, 256], [1, 4096, 256] + iou_token_out = hs[:, 0, :] # [1, 256] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # [1, 5, 256] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) # [1, 256, 64, 64] + + upscaled_embedding_sam = self.output_upscaling(src) # [1, 32, 256, 256] + upscaled_embedding_ours = self.embedding_maskfeature(upscaled_embedding_sam) + matting_feature # [1, 32, 256, 256] + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + if i < 4: + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + else: + hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :])) + + hyper_in = torch.stack(hyper_in_list, dim=1) # 5 * [1, 32] --> [1, 5, 32] + b, c, h, w = upscaled_embedding_sam.shape + + masks_sam = (hyper_in[:,:4] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) # [1, 4, 32] @ [1, 32, 65536] --> [1, 4, 256, 256] + masks_ours = (hyper_in[:,4:] @ upscaled_embedding_ours.view(b, c, h * w)).view(b, -1, h, w) # [1, 1, 32] @ [1, 32, 65536] --> [1, 1, 256, 256] + masks = torch.cat([masks_sam,masks_ours], dim=1) # [1, 5, 256, 256] + + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred \ No newline at end of file diff --git a/modeling/semantic_enhanced_matting/modeling/__init__.py b/modeling/semantic_enhanced_matting/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0a07c7f8b75a3d1882bd4e7a4a3bc83e9da51c --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder_hq import MaskDecoderHQ +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer +from .tiny_vit_sam import TinyViT diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/__init__.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fab22e68038c55b95e55130303275bf62aa1d8c0 Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/common.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/common.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49b83405778545969ce45fae770bf5d0b6c4f7cc Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/common.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/image_encoder.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/image_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0984055070520495df8c875e099f0a77ad3e84e5 Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/image_encoder.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec75a3ace2bf1978f084adfdf05be12e0c0ebdbb Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ede602c19484a67e6ea264872a2a43bb849ca332 Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq_matting.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq_matting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8819dd5b191c831acad58b371c383a06a51dc559 Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq_matting.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/prompt_encoder.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/prompt_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..829c46dbb26b33310f8538b01e1bf503df803928 Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/prompt_encoder.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/sam.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/sam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..181fefe93e565ea8dfd5344e05444dcfd414ef80 Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/sam.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/tiny_vit_sam.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/tiny_vit_sam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f02aa71dfd72b5235064230fd965b9a21704820b Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/tiny_vit_sam.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/transformer.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3809a541f8d339f52f0d6c3f037fa2212e5ec10d Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/transformer.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/modeling/common.py b/modeling/semantic_enhanced_matting/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96 --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/modeling/semantic_enhanced_matting/modeling/image_encoder.py b/modeling/semantic_enhanced_matting/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9353c9c0c63cf7447ff5fd14fca7a0e9ee74ce17 --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/image_encoder.py @@ -0,0 +1,447 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_size = patch_size + self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, img: torch.Tensor, condition: torch.Tensor = None, condition_layer = None) -> torch.Tensor: + + x = self.patch_embed(img) + if isinstance(condition_layer, nn.ModuleDict) and condition is not None: + # concat mask and img as condition + bbox_mask = torch.zeros_like(img)[:, 0:1] + for i in range(condition.shape[0]): + l, u, r, d = condition[i, 0, :] + bbox_mask[i, :, int(u): int(d), int(l): int(r)] = 1.0 + condition_input = torch.concat((img, bbox_mask), dim=1) + + x = x + condition_layer['patch_embed'](condition_input) + + if self.pos_embed is not None: + x = x + self.pos_embed + + index = 0 + interm_embeddings = [] + pred_trimap = [] + for blk in self.blocks: + x = blk(x) + if blk.window_size == 0: + + interm_embeddings.append(x) + + # pred intern triamp + if isinstance(condition_layer, nn.ModuleDict) and '{}_pred_layer'.format(index) in condition_layer.keys() and condition is not None: + pred_trimap.append(condition_layer['{}_pred_layer'.format(index)](x.permute(0, 3, 1, 2))) + + # add intern prompt + if isinstance(condition_layer, nn.ModuleList): + x = condition_layer[index](x, condition) + elif isinstance(condition_layer, nn.ModuleDict) and 'prompt_layer' in condition_layer.keys() and condition is not None: + x = x + condition_layer['prompt_layer'](x, condition, index) + + index += 1 + + x = self.neck(x.permute(0, 3, 1, 2)) + + if isinstance(condition_layer, nn.ModuleDict) and len(pred_trimap) != 0 and condition is not None: + pred_trimap = sum(pred_trimap) / len(pred_trimap) + pred_trimap = F.interpolate(pred_trimap, size=(img.shape[-2], img.shape[-1]), mode='bilinear', align_corners=False) + else: + pred_trimap = None + + if isinstance(condition_layer, nn.ModuleDict) and 'feature_fusion' in condition_layer.keys() and condition is not None: + interm_embeddings = condition_layer['feature_fusion'](interm_embeddings) + + return x, interm_embeddings, pred_trimap + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward_token(self, x: torch.Tensor) -> torch.Tensor: + B, N, _ = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.reshape(3, B * self.num_heads, N, -1).unbind(0) + attn = (q * self.scale) @ k.transpose(-2, -1) + assert not self.use_rel_pos + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, N, -1).permute(0, 2, 1, 3).reshape(B, N, -1) + x = self.proj(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/modeling/semantic_enhanced_matting/modeling/mask_decoder.py b/modeling/semantic_enhanced_matting/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..caddba072321a2234b9185e70de1abc6799c6f9c --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/mask_decoder.py @@ -0,0 +1,178 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + hq_token_only: bool, + interm_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq.py b/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bc9d8096e3d140422d7f60a1b5860184b53c15 --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Modified by HQ-SAM team +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoderHQ(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + vit_dim: int = 1024, + w_all_logits: bool = False, + wo_hq: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + self.vit_dim = vit_dim + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + self.wo_hq = wo_hq + if not self.wo_hq: + # HQ-SAM parameters + self.hf_token = nn.Embedding(1, transformer_dim) # HQ-Ouptput-Token + self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) # corresponding new MLP layer for HQ-Ouptput-Token + self.num_mask_tokens = self.num_mask_tokens + 1 + + # three conv fusion layers for obtaining HQ-Feature + self.compress_vit_feat = nn.Sequential( + nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2)) + + self.embedding_encoder = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + ) + self.embedding_maskfeature = nn.Sequential( + nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1), + LayerNorm2d(transformer_dim // 4), + nn.GELU(), + nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1)) + + self.w_all_logits = w_all_logits + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + hq_token_only: bool, + interm_embeddings: torch.Tensor, + return_hq_features_type: str, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the ViT image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + if not self.wo_hq: + vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT + hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) + return_hq_features = None + if return_hq_features_type == 'Early': + return_hq_features = hq_features + else: + hq_features = None + + masks, iou_pred, mid_fin_hq_features = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + hq_features=hq_features, + return_hq_features_type=return_hq_features_type + ) + if return_hq_features_type in {'Middle', 'Final'}: + return_hq_features = mid_fin_hq_features + + # Select the correct mask or masks for output + if multimask_output: + # mask with highest score + if not self.wo_hq: + mask_slice = slice(1, self.num_mask_tokens-1) + else: + mask_slice = slice(1, self.num_mask_tokens) + iou_pred = iou_pred[:, mask_slice] + iou_pred, max_iou_idx = torch.max(iou_pred,dim=1) + iou_pred = iou_pred.unsqueeze(1) + masks_multi = masks[:, mask_slice, :, :] + masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1) + else: + # singale mask output, default + mask_slice = slice(0, 1) + iou_pred = iou_pred[:,mask_slice] + masks_sam = masks[:,mask_slice] + + if not self.wo_hq: + masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens)] + + if hq_token_only: + low_res_masks = masks_hq + else: + low_res_masks = masks_sam + masks_hq + + if return_hq_features_type == 'Final': + return_hq_features = low_res_masks + + if self.w_all_logits: + return_hq_features = masks + else: + low_res_masks = masks_sam + return_hq_features = masks_sam + # Prepare output + return low_res_masks, masks_sam, return_hq_features + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + hq_features: torch.Tensor, + return_hq_features_type: str + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + if not self.wo_hq: + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0) + else: + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + upscaled_embedding_sam = self.output_upscaling(src) + if not self.wo_hq: + upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1) + return_hq_features = None + if return_hq_features_type == 'Middle': + return_hq_features = upscaled_embedding_hq + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + if i < self.num_mask_tokens - 1 or self.wo_hq: + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + else: + hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :])) + + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding_sam.shape + if not self.wo_hq: + masks_sam = (hyper_in[:,:self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) + masks_sam_hq = (hyper_in[:,self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w) + if return_hq_features_type == 'Final': + return_hq_features = masks_sam_hq + masks = torch.cat([masks_sam, masks_sam_hq],dim=1) + else: + masks_sam = (hyper_in @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) + return_hq_features = masks_sam + masks = masks_sam + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred, return_hq_features + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq_matting.py b/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq_matting.py new file mode 100644 index 0000000000000000000000000000000000000000..59ff8f44cb1419ebcbcd957a5b8119afba74e1d7 --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq_matting.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Modified by HQ-SAM team +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d +from .mask_decoder_hq import MaskDecoderHQ, MLP + + +class MaskDecoderHQMatting(MaskDecoderHQ): + def __init__( + self, + hq_token_only=False, + matting_token_num=1, + mask_matting_res_add=True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.hq_token_only = hq_token_only + self.matting_token_num = matting_token_num + self.mask_matting_res_add = mask_matting_res_add + if not self.mask_matting_res_add: + assert self.wo_hq + + # Matting token parameters + self.matting_hf_token = nn.Embedding(self.matting_token_num, self.transformer_dim) # Matting-Ouptput-Token + self.matting_hf_mlp = MLP(self.transformer_dim, self.transformer_dim, self.transformer_dim // 8, 3) # corresponding new MLP layer for Matting-Ouptput-Token + self.num_mask_tokens = self.num_mask_tokens + self.matting_token_num + + # three conv fusion layers for obtaining Matting-Feature + self.matting_compress_vit_feat = nn.Sequential( + nn.ConvTranspose2d(self.vit_dim, self.transformer_dim, kernel_size=2, stride=2), + LayerNorm2d(self.transformer_dim), + nn.GELU(), + nn.ConvTranspose2d(self.transformer_dim, self.transformer_dim // 8, kernel_size=2, stride=2)) + + self.matting_embedding_encoder = nn.Sequential( + nn.ConvTranspose2d(self.transformer_dim, self.transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(self.transformer_dim // 4), + nn.GELU(), + nn.ConvTranspose2d(self.transformer_dim // 4, self.transformer_dim // 8, kernel_size=2, stride=2), + ) + self.matting_embedding_maskfeature = nn.Sequential( + nn.Conv2d(self.transformer_dim // 8, self.transformer_dim // 4, 3, 1, 1), + LayerNorm2d(self.transformer_dim // 4), + nn.GELU(), + nn.Conv2d(self.transformer_dim // 4, self.transformer_dim // 8, 3, 1, 1)) + + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + interm_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the ViT image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT + if not self.wo_hq: + hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) + else: + hq_features = None + matting_hq_features = self.matting_embedding_encoder(image_embeddings) + self.matting_compress_vit_feat(vit_features) + + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + hq_features=hq_features, + matting_hq_features=matting_hq_features + ) + + # Select the correct mask or masks for output + if multimask_output: + # mask with highest score + if not self.wo_hq: + mask_slice = slice(1,self.num_mask_tokens - (self.matting_token_num + 1)) # matting_token_num + hq_token_num + else: + mask_slice = slice(1,self.num_mask_tokens - self.matting_token_num) # matting_token_num + iou_pred = iou_pred[:, mask_slice] + iou_pred, max_iou_idx = torch.max(iou_pred,dim=1) + iou_pred = iou_pred.unsqueeze(1) + masks_multi = masks[:, mask_slice, :, :] + masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1) + else: + # singale mask output, default + mask_slice = slice(0, 1) + iou_pred = iou_pred[:,mask_slice] + masks_sam = masks[:,mask_slice] + + if not self.wo_hq: + masks_hq = masks[:,slice(self.num_mask_tokens - (self.matting_token_num + 1), self.num_mask_tokens - self.matting_token_num)] + masks_matting = masks[:,slice(self.num_mask_tokens - self.matting_token_num, self.num_mask_tokens)] + + if not self.wo_hq: + if self.hq_token_only: + # masks_hq += masks_sam + masks_matting += masks_hq + else: + masks_hq += masks_sam + masks_matting += masks_hq + else: + masks_hq = masks_sam + if self.mask_matting_res_add: + masks_matting = masks_sam + masks_matting + else: + masks_matting = masks_matting + # Prepare output + return {'masks_sam': masks_sam, 'masks_hq': masks_hq, 'masks_matting': masks_matting} + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + hq_features: torch.Tensor, + matting_hq_features: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + if not self.wo_hq: + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight, self.matting_hf_token.weight], dim=0) + else: + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.matting_hf_token.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + + upscaled_embedding_sam = self.output_upscaling(src) + if not self.wo_hq: + upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1) + upscaled_embedding_matting_hq = self.matting_embedding_maskfeature(upscaled_embedding_sam) + matting_hq_features.repeat(b,1,1,1) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + if i < self.num_mask_tokens - (self.matting_token_num + 1) or (self.wo_hq and i < self.num_mask_tokens - self.matting_token_num): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + elif i == self.num_mask_tokens - (self.matting_token_num + 1): + hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :])) + else: + hyper_in_list.append(self.matting_hf_mlp(mask_tokens_out[:, i, :])) + + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding_sam.shape + + if not self.wo_hq: + masks_sam = (hyper_in[:,:self.num_mask_tokens - (self.matting_token_num + 1)] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) + masks_sam_hq = (hyper_in[:,self.num_mask_tokens - (self.matting_token_num + 1) : self.num_mask_tokens - self.matting_token_num] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w) + else: + masks_sam = (hyper_in[:,:self.num_mask_tokens - self.matting_token_num] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w) + masks_sam_matting_hq = (hyper_in[:, self.num_mask_tokens - self.matting_token_num:] @ upscaled_embedding_matting_hq.view(b, c, h * w)).view(b, -1, h, w) + + if not self.wo_hq: + masks = torch.cat([masks_sam, masks_sam_hq, masks_sam_matting_hq],dim=1) + else: + masks = torch.cat([masks_sam, masks_sam_matting_hq],dim=1) + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred diff --git a/modeling/semantic_enhanced_matting/modeling/prompt_encoder.py b/modeling/semantic_enhanced_matting/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3 --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/modeling/semantic_enhanced_matting/modeling/sam.py b/modeling/semantic_enhanced_matting/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..dfca5c7b72253ab390036d565aeed7a19405a40d --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/sam.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + hq_token_only: bool =False, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings, interm_embeddings = self.image_encoder(input_images) + interm_embeddings = interm_embeddings[0] # early layer + + outputs = [] + for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0), + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs, + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/modeling/semantic_enhanced_matting/modeling/tiny_vit_sam.py b/modeling/semantic_enhanced_matting/modeling/tiny_vit_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..65f04aa374599f6bb70fe69c81660df9d4e786e1 --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/tiny_vit_sam.py @@ -0,0 +1,724 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath as TimmDropPath,\ + to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from typing import Tuple + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f'(drop_prob={self.drop_prob})' + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, + activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, + ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN( + self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c=2 + if(out_dim==320 or out_dim==448 or out_dim==576): + stride_c=1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__(self, dim, input_resolution, depth, + activation, + drop_path=0., downsample=None, use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4., + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + MBConv(dim, dim, conv_expand_ratio, activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product( + range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.register_buffer('ab', + self.attention_biases[:, self.attention_bias_idxs], + persistent=False) + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, - + 1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ( + (q @ k.transpose(-2, -1)) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r""" TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, + mlp_ratio=4., drop=0., drop_path=0., + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, + attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % + self.window_size) % self.window_size + pad_r = (self.window_size - W % + self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +class BasicLayer(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. Default: dim + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., drop=0., + drop_path=0., downsample=None, use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + TinyViTBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x +class TinyViT(nn.Module): + def __init__(self, img_size=224, in_chans=3, num_classes=1000, + embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size=img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed(in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict(dim=embed_dims[i_layer], + input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + downsample=PatchMerging if ( + i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min( + i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + #print("LR SCALES:", lr_scales) + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply( + lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, 'lr_scale'), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + interm_embeddings=[] + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + # print('x shape:', x.shape, '---i:', i) + if i == 1: + interm_embeddings.append(x.view(x.shape[0], 64, 64, -1)) + + B,_,C=x.size() + x = x.view(B, 64, 64, C) + x=x.permute(0, 3, 1, 2) + x=self.neck(x) + return x, interm_embeddings + + def forward(self, x): + x, interm_embeddings = self.forward_features(x) + #x = self.norm_head(x) + #x = self.head(x) + # print('come to here is correct'* 3) + return x, interm_embeddings + + +_checkpoint_url_format = \ + 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth' +_provided_checkpoints = { + 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill', + 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill', + 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill', + 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill', + 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill', +} + + +def register_tiny_vit_model(fn): + '''Register a TinyViT model + It is a wrapper of `register_model` with loading the pretrained checkpoint. + ''' + def fn_wrapper(pretrained=False, **kwargs): + model = fn() + if pretrained: + model_name = fn.__name__ + assert model_name in _provided_checkpoints, \ + f'Sorry that the checkpoint `{model_name}` is not provided yet.' + url = _checkpoint_url_format.format( + _provided_checkpoints[model_name]) + checkpoint = torch.hub.load_state_dict_from_url( + url=url, + map_location='cpu', check_hash=False, + ) + model.load_state_dict(checkpoint['model']) + + return model + + # rename the name of fn_wrapper + fn_wrapper.__name__ = fn.__name__ + return register_model(fn_wrapper) + + +@register_tiny_vit_model +def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): + return TinyViT( + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=384, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=512, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=drop_path_rate, + ) diff --git a/modeling/semantic_enhanced_matting/modeling/transformer.py b/modeling/semantic_enhanced_matting/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a --- /dev/null +++ b/modeling/semantic_enhanced_matting/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/modeling/semantic_enhanced_matting/predictor.py b/modeling/semantic_enhanced_matting/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..73d078d03b6190985d8fa611231efa6844a6d247 --- /dev/null +++ b/modeling/semantic_enhanced_matting/predictor.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size if hasattr(sam_model.image_encoder, 'img_size') else sam_model.image_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + # import pdb;pdb.set_trace() + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + # import pdb;pdb.set_trace() + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features, self.interm_features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + hq_token_only: bool =False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + hq_token_only=hq_token_only, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + hq_token_only: bool =False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + interm_embeddings=self.interm_features, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/modeling/semantic_enhanced_matting/utils/__init__.py b/modeling/semantic_enhanced_matting/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/modeling/semantic_enhanced_matting/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/modeling/semantic_enhanced_matting/utils/__pycache__/__init__.cpython-38.pyc b/modeling/semantic_enhanced_matting/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f79164a21c51300b8d96a5a74c38600cb3fb4305 Binary files /dev/null and b/modeling/semantic_enhanced_matting/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/utils/__pycache__/amg.cpython-38.pyc b/modeling/semantic_enhanced_matting/utils/__pycache__/amg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..583eb6bdabf07e41d55ce5df361ccaa702a47d70 Binary files /dev/null and b/modeling/semantic_enhanced_matting/utils/__pycache__/amg.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/utils/__pycache__/transforms.cpython-38.pyc b/modeling/semantic_enhanced_matting/utils/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..084df8b376efbe2fac8deca61d316bf8730cc6d5 Binary files /dev/null and b/modeling/semantic_enhanced_matting/utils/__pycache__/transforms.cpython-38.pyc differ diff --git a/modeling/semantic_enhanced_matting/utils/amg.py b/modeling/semantic_enhanced_matting/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..be064071ef399fea96c673ad173689656c23534a --- /dev/null +++ b/modeling/semantic_enhanced_matting/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/modeling/semantic_enhanced_matting/utils/onnx.py b/modeling/semantic_enhanced_matting/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..8013dc43d0373f1d84cd7ff7950822ff12b82a82 --- /dev/null +++ b/modeling/semantic_enhanced_matting/utils/onnx.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + hq_token_only: bool = False, + multimask_output: bool = False, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.hq_token_only = hq_token_only + self.multimask_output = multimask_output + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + interm_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT + hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + hq_features=hq_features, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.multimask_output: + # mask with highest score + mask_slice = slice(1,self.model.mask_decoder.num_mask_tokens-1) + scores = scores[:, mask_slice] + scores, max_iou_idx = torch.max(scores,dim=1) + scores = scores.unsqueeze(1) + masks_multi = masks[:, mask_slice, :, :] + masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1) + else: + # singale mask output, default + mask_slice = slice(0, 1) + scores = scores[:,mask_slice] + masks_sam = masks[:,mask_slice] + + masks_hq = masks[:,slice(self.model.mask_decoder.num_mask_tokens-1, self.model.mask_decoder.num_mask_tokens)] + + if self.hq_token_only: + masks = masks_hq + else: + masks = masks_sam + masks_hq + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/modeling/semantic_enhanced_matting/utils/transforms.py b/modeling/semantic_enhanced_matting/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c08ba1e3db751f3a5483a003be38c69c2cf2df85 --- /dev/null +++ b/modeling/semantic_enhanced_matting/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/pretrained/preprocess.py b/pretrained/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..b28412b4329757ed8acb0dc5978539a0e4fc4645 --- /dev/null +++ b/pretrained/preprocess.py @@ -0,0 +1,27 @@ +import torch +import wget + +def preprocess(model, name='dino', embed_dim=384): + new_model = {} + for k in model.keys(): + if 'patch_embed.proj.weight' in k: + x = torch.zeros(embed_dim, 4, 16, 16) + x[:, :3] = model[k] + new_model['backbone.'+k] = x + else: + new_model['backbone.'+k] = model[k] + if embed_dim==384: + size='s' + else: + size='b' + torch.save(new_model, name+'_vit_'+ size + '_fna.pth') + +if __name__ == "__main__": + + wget.download('https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth') + wget.download('https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth') + + dino_model = torch.load('dino_deitsmall16_pretrain.pth') + mae_model = torch.load('mae_pretrain_vit_base.pth')['model'] + preprocess(dino_model, 'dino', 384) + preprocess(mae_model, 'mae', 768) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4847f250554f6dd66d546a7719b24d6705094a42 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +torch==2.0.0 +torchvision +tensorboard +timm==0.5.4 +opencv-python==4.5.3.56 +setuptools==58.2.0 +easydict +wget +scikit-image +fairscale +imgaug +peft +kornia +gradio==4.44.1 +gradio_image_prompter +huggingface_hub +detectron2 @ git+https://github.com/facebookresearch/detectron2@v0.6 diff --git a/sam2/__init__.py b/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff90d1042707d3190277036f3fd8e0ff177fd365 --- /dev/null +++ b/sam2/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from hydra import initialize_config_module + +initialize_config_module("sam2_configs", version_base="1.2") diff --git a/sam2/__pycache__/__init__.cpython-38.pyc b/sam2/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95dd7344f8c877202887c8f36677189d99419e57 Binary files /dev/null and b/sam2/__pycache__/__init__.cpython-38.pyc differ diff --git a/sam2/__pycache__/build_sam.cpython-38.pyc b/sam2/__pycache__/build_sam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de2c2814ddcc56f027d4e49dfba22793f8e64229 Binary files /dev/null and b/sam2/__pycache__/build_sam.cpython-38.pyc differ diff --git a/sam2/__pycache__/sam2_image_predictor.cpython-38.pyc b/sam2/__pycache__/sam2_image_predictor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d44530cce89b105284fcc04952805cc49123b39 Binary files /dev/null and b/sam2/__pycache__/sam2_image_predictor.cpython-38.pyc differ diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..67668b2fc937010d237f8195d5c059c7cc481a3e --- /dev/null +++ b/sam2/automatic_mask_generator.py @@ -0,0 +1,434 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.utils.amg import ( + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + MaskData, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SAM2AutomaticMaskGenerator: + def __init__( + self, + model: SAM2Base, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.8, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + mask_threshold: float = 0.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + use_m2m: bool = False, + multimask_output: bool = True, + ) -> None: + """ + Using a SAM 2 model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM 2 with a HieraL backbone. + + Arguments: + model (Sam): The SAM 2 model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + mask_threshold (float): Threshold for binarizing the mask logits + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + use_m2m (bool): Whether to add a one step refinement using previous mask predictions. + multimask_output (bool): Whether to output multimask at each point of the grid. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + try: + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + except ImportError as e: + print("Please install pycocotools") + raise e + + self.predictor = SAM2ImagePredictor( + model, + max_hole_area=min_mask_region_area, + max_sprinkle_area=min_mask_region_area, + ) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.mask_threshold = mask_threshold + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.use_m2m = use_m2m + self.multimask_output = multimask_output + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size, normalize=True + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + normalize=False, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + points = torch.as_tensor(points, device=self.predictor.device) + in_points = self.predictor._transforms.transform_coords( + points, normalize=normalize, orig_hw=im_size + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) + del masks + + if not self.use_m2m: + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + else: + # One step refinement using previous mask predictions + in_points = self.predictor._transforms.transform_coords( + data["points"], normalize=normalize, orig_hw=im_size + ) + labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, ious = self.refine_with_m2m( + in_points, labels, data["low_res_masks"], self.points_per_batch + ) + data["masks"] = masks.squeeze(1) + data["iou_preds"] = ious.squeeze(1) + + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): + new_masks = [] + new_iou_preds = [] + + for cur_points, cur_point_labels, low_res_mask in batch_iterator( + points_per_batch, points, point_labels, low_res_masks + ): + best_masks, best_iou_preds, _ = self.predictor._predict( + cur_points[:, None, :], + cur_point_labels[:, None], + mask_input=low_res_mask[:, None, :], + multimask_output=False, + return_logits=True, + ) + new_masks.append(best_masks) + new_iou_preds.append(best_iou_preds) + masks = torch.cat(new_masks, dim=0) + return masks, torch.cat(new_iou_preds, dim=0) diff --git a/sam2/build_sam.py b/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..ad12041ee9d689a05a7c43149c183ffe84993336 --- /dev/null +++ b/sam2/build_sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + bbox_mask_matting_token = False, + matting_logits_res_add = False, + upscaled_embedding_res_add = True, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + + if bbox_mask_matting_token: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + "++model.mask_decoder_matting_token=true", + "++model.image_encoder.trunk._target_=sam2.modeling.backbones.hieradet.HieraBBoxMask", + "++model.matting_logits_res_add=true" if matting_logits_res_add else "++model.matting_logits_res_add=false", + "++model.upscaled_embedding_res_add=true" if upscaled_embedding_res_add else "++model.upscaled_embedding_res_add=false", + ] + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path, add_new_layer_weights=True) + + model = model.to(device) + if mode == "eval": + model.eval() + + if bbox_mask_matting_token: + for n, p in model.named_parameters(): + if 'matting' in n or 'bbox_mask' in n: + p.requires_grad = True + else: + p.requires_grad = False + + return model + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, +): + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_hf(model_id, **kwargs): + + from huggingface_hub import hf_hub_download + + model_id_to_filenames = { + "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), + "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), + "facebook/sam2-hiera-base-plus": ( + "sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), + } + config_name, checkpoint_name = model_id_to_filenames[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +def build_sam2_video_predictor_hf(model_id, **kwargs): + + from huggingface_hub import hf_hub_download + + model_id_to_filenames = { + "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), + "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), + "facebook/sam2-hiera-base-plus": ( + "sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), + } + config_name, checkpoint_name = model_id_to_filenames[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return build_sam2_video_predictor( + config_file=config_name, ckpt_path=ckpt_path, **kwargs + ) + + +def _load_checkpoint(model, ckpt_path, add_new_layer_weights=False): + # if add_new_layer_weights: + # assert ckpt_path is not None + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu")["model"] + if add_new_layer_weights: + + # bbox patch embed + sd['image_encoder.trunk.bbox_mask_patch_embed.proj.weight'] = torch.concat(( + sd['image_encoder.trunk.patch_embed.proj.weight'], + torch.mean(sd['image_encoder.trunk.patch_embed.proj.weight'], dim=1, keepdim=True) + ), dim=1) + sd['image_encoder.trunk.bbox_mask_patch_embed.proj.bias'] = sd['image_encoder.trunk.patch_embed.proj.bias'] + + # matting token + sd['sam_mask_decoder.matting_mask_tokens.weight'] = torch.mean(sd['sam_mask_decoder.mask_tokens.weight'], dim=0, keepdim=True).repeat(model.sam_mask_decoder.matting_token_num, 1) + + output_hypernetworks_mlps_0_keys = [key for key in sd.keys() if 'output_hypernetworks_mlps.0' in key] + for i in range(model.sam_mask_decoder.matting_token_num): + for key in output_hypernetworks_mlps_0_keys: + target_key = key.replace('output_hypernetworks_mlps.0', 'matting_output_hypernetworks_mlps.{}'.format(i)) + sd[target_key] = sd[key] + + output_upscaling_keys = [key for key in sd.keys() if 'output_upscaling' in key] + for key in output_upscaling_keys: + target_key = key.replace('output_upscaling', 'matting_output_upscaling') + sd[target_key] = sd[key] + + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/sam2/csrc/connected_components.cu b/sam2/csrc/connected_components.cu new file mode 100644 index 0000000000000000000000000000000000000000..ced21eb32eaaadb818d441c1322b99d1bf068f45 --- /dev/null +++ b/sam2/csrc/connected_components.cu @@ -0,0 +1,289 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. + +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// adapted from https://github.com/zsef123/Connected_components_PyTorch +// with license found in the LICENSE_cctorch file in the root directory. +#include +#include +#include +#include +#include +#include + +// 2d +#define BLOCK_ROWS 16 +#define BLOCK_COLS 16 + +namespace cc2d { + +template +__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { + return (bitmap >> pos) & 1; +} + +__device__ int32_t find(const int32_t* s_buf, int32_t n) { + while (s_buf[n] != n) + n = s_buf[n]; + return n; +} + +__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { + const int32_t id = n; + while (s_buf[n] != n) { + n = s_buf[n]; + s_buf[id] = n; + } + return n; +} + +__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { + bool done; + do { + a = find(s_buf, a); + b = find(s_buf, b); + + if (a < b) { + int32_t old = atomicMin(s_buf + b, a); + done = (old == b); + b = old; + } else if (b < a) { + int32_t old = atomicMin(s_buf + a, b); + done = (old == a); + a = old; + } else + done = true; + + } while (!done); +} + +__global__ void +init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + label[idx] = idx; +} + +__global__ void +merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + uint32_t P = 0; + + if (img[idx]) + P |= 0x777; + if (row + 1 < H && img[idx + W]) + P |= 0x777 << 4; + if (col + 1 < W && img[idx + 1]) + P |= 0x777 << 1; + + if (col == 0) + P &= 0xEEEE; + if (col + 1 >= W) + P &= 0x3333; + else if (col + 2 >= W) + P &= 0x7777; + + if (row == 0) + P &= 0xFFF0; + if (row + 1 >= H) + P &= 0xFF; + + if (P > 0) { + // If need check about top-left pixel(if flag the first bit) and hit the + // top-left pixel + if (hasBit(P, 0) && img[idx - W - 1]) { + union_(label, idx, idx - 2 * W - 2); // top left block + } + + if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) + union_(label, idx, idx - 2 * W); // top bottom block + + if (hasBit(P, 3) && img[idx + 2 - W]) + union_(label, idx, idx - 2 * W + 2); // top right block + + if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) + union_(label, idx, idx - 2); // just left block + } +} + +__global__ void compression(int32_t* label, const int32_t W, const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + find_n_compress(label, idx); +} + +__global__ void final_labeling( + const uint8_t* img, + int32_t* label, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx] + 1; + + if (img[idx]) + label[idx] = y; + else + label[idx] = 0; + + if (col + 1 < W) { + if (img[idx + 1]) + label[idx + 1] = y; + else + label[idx + 1] = 0; + + if (row + 1 < H) { + if (img[idx + W + 1]) + label[idx + W + 1] = y; + else + label[idx + W + 1] = 0; + } + } + + if (row + 1 < H) { + if (img[idx + W]) + label[idx + W] = y; + else + label[idx + W] = 0; + } +} + +__global__ void init_counting( + const int32_t* label, + int32_t* count_init, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + atomicAdd(count_init + count_idx, 1); + } +} + +__global__ void final_counting( + const int32_t* label, + const int32_t* count_init, + int32_t* count_final, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + count_final[idx] = count_init[count_idx]; + } else { + count_final[idx] = 0; + } +} + +} // namespace cc2d + +std::vector get_connected_componnets( + const torch::Tensor& inputs) { + AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); + AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM( + inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); + + const uint32_t N = inputs.size(0); + const uint32_t C = inputs.size(1); + const uint32_t H = inputs.size(2); + const uint32_t W = inputs.size(3); + + AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM((H % 2) == 0, "height must be an even number"); + AT_ASSERTM((W % 2) == 0, "width must be an even number"); + + // label must be uint32_t + auto label_options = + torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); + torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); + + dim3 grid = dim3( + ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, + ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); + dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); + dim3 grid_count = + dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); + dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + for (int n = 0; n < N; n++) { + uint32_t offset = n * H * W; + + cc2d::init_labeling<<>>( + labels.data_ptr() + offset, W, H); + cc2d::merge<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + cc2d::compression<<>>( + labels.data_ptr() + offset, W, H); + cc2d::final_labeling<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + + // get the counting of each pixel + cc2d::init_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + W, + H); + cc2d::final_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + counts_final.data_ptr() + offset, + W, + H); + } + + // returned values are [labels, counts] + std::vector outputs; + outputs.push_back(labels); + outputs.push_back(counts_final); + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "get_connected_componnets", + &get_connected_componnets, + "get_connected_componnets"); +} diff --git a/sam2/modeling/__init__.py b/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/__pycache__/__init__.cpython-38.pyc b/sam2/modeling/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d5d8ad6db0de1c412842a6180a9102a6710f3c6 Binary files /dev/null and b/sam2/modeling/__pycache__/__init__.cpython-38.pyc differ diff --git a/sam2/modeling/__pycache__/memory_attention.cpython-38.pyc b/sam2/modeling/__pycache__/memory_attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b326cce83183de1df2f78510bd84c3b3e282bf97 Binary files /dev/null and b/sam2/modeling/__pycache__/memory_attention.cpython-38.pyc differ diff --git a/sam2/modeling/__pycache__/memory_encoder.cpython-38.pyc b/sam2/modeling/__pycache__/memory_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8dfda2877a8c75024f2df773d5e8edcc8ed844c Binary files /dev/null and b/sam2/modeling/__pycache__/memory_encoder.cpython-38.pyc differ diff --git a/sam2/modeling/__pycache__/position_encoding.cpython-38.pyc b/sam2/modeling/__pycache__/position_encoding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12347e45deec1f2d3157a9302d1126a482710598 Binary files /dev/null and b/sam2/modeling/__pycache__/position_encoding.cpython-38.pyc differ diff --git a/sam2/modeling/__pycache__/sam2_base.cpython-38.pyc b/sam2/modeling/__pycache__/sam2_base.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f07fe8f41df4ca9f93b8b0bfb35f82103e71fbbe Binary files /dev/null and b/sam2/modeling/__pycache__/sam2_base.cpython-38.pyc differ diff --git a/sam2/modeling/__pycache__/sam2_utils.cpython-38.pyc b/sam2/modeling/__pycache__/sam2_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ddc6ab7a35cf4b9e2ad30d5375c9cd926b8c91e Binary files /dev/null and b/sam2/modeling/__pycache__/sam2_utils.cpython-38.pyc differ diff --git a/sam2/modeling/backbones/__init__.py b/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/backbones/__pycache__/__init__.cpython-38.pyc b/sam2/modeling/backbones/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..010a69e49cf524fdf0dedb354fadaa2807dddd04 Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/__init__.cpython-38.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/hieradet.cpython-38.pyc b/sam2/modeling/backbones/__pycache__/hieradet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a7a555aeaf7364acd19e6fc95c81630d90f7251 Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/hieradet.cpython-38.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/image_encoder.cpython-38.pyc b/sam2/modeling/backbones/__pycache__/image_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4933839c75088fd93b7083ad4f5b004546921b1 Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/image_encoder.cpython-38.pyc differ diff --git a/sam2/modeling/backbones/__pycache__/utils.cpython-38.pyc b/sam2/modeling/backbones/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..424099ce2487c1c3778225acda64f768bbaac87e Binary files /dev/null and b/sam2/modeling/backbones/__pycache__/utils.cpython-38.pyc differ diff --git a/sam2/modeling/backbones/hieradet.py b/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..690041c6f70d967d79b5c4e84b1fd97e54b2e242 --- /dev/null +++ b/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs + + +class HieraBBoxMask(Hiera): + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.bbox_mask_patch_embed = PatchEmbed( + in_chans=4, + embed_dim=self.patch_embed.proj.out_channels, + ) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + + # x = self.patch_embed(x) + + img, condition = x[0], x[1] + if condition is not None: + # concat mask and img as condition + bbox_mask = torch.zeros_like(img)[:, 0:1] + for i in range(condition.shape[0]): + l, u, r, d = condition[i, 0, :] + bbox_mask[i, :, int(u): int(d), int(l): int(r)] = 1.0 + condition_input = torch.concat((img, bbox_mask), dim=1) + x = self.patch_embed(img) + self.bbox_mask_patch_embed(condition_input) + else: + x = self.patch_embed(img) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs \ No newline at end of file diff --git a/sam2/modeling/backbones/image_encoder.py b/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5f92baf47dcab96385ff99899fd3e3a642c1cf9c --- /dev/null +++ b/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7 --- /dev/null +++ b/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/sam2/modeling/memory_attention.py b/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0b07f9d87e3d8194ca5e11fc20f01604d591a59d --- /dev/null +++ b/sam2/modeling/memory_attention.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from sam2.modeling.sam.transformer import RoPEAttention + +from sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/sam2/modeling/memory_encoder.py b/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f60202dfaba87232c3870fb2101b5322a119d985 --- /dev/null +++ b/sam2/modeling/memory_encoder.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b57ae7b431f3859af1368acdf4597d671cda32 --- /dev/null +++ b/sam2/modeling/position_encoding.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/sam2/modeling/sam/__init__.py b/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/modeling/sam/__pycache__/__init__.cpython-38.pyc b/sam2/modeling/sam/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04d07d5dfdb24c5405223f23e93e4fb4498d37e6 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/__init__.cpython-38.pyc differ diff --git a/sam2/modeling/sam/__pycache__/mask_decoder.cpython-38.pyc b/sam2/modeling/sam/__pycache__/mask_decoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2af91dbdace3f54d6ffd5e14eccd7a265f0abb4 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/mask_decoder.cpython-38.pyc differ diff --git a/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-38.pyc b/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6418f0562c84ba4a72a8841b6e9d010c06496a63 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/prompt_encoder.cpython-38.pyc differ diff --git a/sam2/modeling/sam/__pycache__/transformer.cpython-38.pyc b/sam2/modeling/sam/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebc5538fc8f07a878d4fac3f698729506518b2f6 Binary files /dev/null and b/sam2/modeling/sam/__pycache__/transformer.cpython-38.pyc differ diff --git a/sam2/modeling/sam/mask_decoder.py b/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..775e05895572e884043f4e1c2e72cd6205302661 --- /dev/null +++ b/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,458 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out + + +from copy import deepcopy + +class MaskDecoderMattingToken(MaskDecoder): + def __init__( + self, + matting_token_num = 3, + upscaled_embedding_res_add = True, + matting_logits_res_add = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.matting_token_num = matting_token_num + self.upscaled_embedding_res_add = upscaled_embedding_res_add + self.matting_logits_res_add = matting_logits_res_add + + self.num_mask_tokens = self.num_mask_tokens + self.matting_token_num + self.matting_mask_tokens = nn.Embedding(self.matting_token_num, self.transformer_dim) + self.matting_output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(self.transformer_dim, self.transformer_dim, self.transformer_dim // 8, 3) + for i in range(self.matting_token_num) + ] + ) + self.matting_output_upscaling = deepcopy(self.output_upscaling) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + wo_matting_token: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + sam2_logits, matting_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + wo_matting_token=wo_matting_token, + ) + sam2_logits = sam2_logits[:, 0:1, :, :] + if self.matting_logits_res_add and matting_logits is not None: + matting_logits = matting_logits + sam2_logits + + return sam2_logits, matting_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + wo_matting_token: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + self.matting_mask_tokens.weight + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight, self.matting_mask_tokens.weight], dim=0 + ) + + if wo_matting_token: + output_tokens = output_tokens[:-self.matting_token_num] + + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + # if not self.use_high_res_features:comparison + # upscaled_embedding = self.output_upscaling(src) + # else: + assert self.use_high_res_features + # ori process + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + if not wo_matting_token: + # matting process + matting_dc1, matting_ln1, matting_act1, matting_dc2, matting_act2 = self.matting_output_upscaling + matting_upscaled_embedding = matting_act1(matting_ln1(matting_dc1(src) + feat_s1)) + matting_upscaled_embedding = matting_act2(matting_dc2(matting_upscaled_embedding) + feat_s0) + if self.upscaled_embedding_res_add: + matting_upscaled_embedding = upscaled_embedding + matting_upscaled_embedding # use res form + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens - self.matting_token_num if wo_matting_token else self.num_mask_tokens): + if i < self.num_mask_tokens - self.matting_token_num: + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + # matting token + else: + hyper_in_list.append( + self.matting_output_hypernetworks_mlps[i - (self.num_mask_tokens - self.matting_token_num)](mask_tokens_out[:, i, :]) + ) + + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + + # ori token + sam2_logits = (hyper_in[:, :self.num_mask_tokens - self.matting_token_num] @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + if not wo_matting_token: + # matting token + matting_logits = (hyper_in[:, -self.matting_token_num: ] @ matting_upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + else: + matting_logits = None + # Generate mask quality predictions + # iou_pred = self.iou_prediction_head(iou_token_out) + # if self.pred_obj_scores: + # assert s == 1 + # object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + # else: + # # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + # object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return sam2_logits, matting_logits \ No newline at end of file diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6b3bbb95be0aea9c88f49f586ac959a9fda1b18b --- /dev/null +++ b/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.position_encoding import PositionEmbeddingRandom + +from sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b6fa2f87e85a7f222fb2ba0b661734dc57a08a --- /dev/null +++ b/sam2/modeling/sam/transformer.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from sam2.modeling.sam2_utils import MLP +from sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. + """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..b79524c80f1453b96361c9ffeede6ad08fe87bf0 --- /dev/null +++ b/sam2/modeling/sam2_base.py @@ -0,0 +1,864 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from sam2.modeling.sam.mask_decoder import MaskDecoder, MaskDecoderMattingToken +from sam2.modeling.sam.prompt_encoder import PromptEncoder +from sam2.modeling.sam.transformer import TwoWayTransformer +from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + # xrh defined + mask_decoder_matting_token: bool = False, + matting_logits_res_add: bool = False, + upscaled_embedding_res_add: bool = True, + ): + super().__init__() + + # xrh + self.mask_decoder_matting_token = mask_decoder_matting_token + self.matting_logits_res_add = matting_logits_res_add + self.upscaled_embedding_res_add = upscaled_embedding_res_add + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + + self._build_sam_heads() + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference." + "See notebooks/video_predictor_example.ipynb for an example." + ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + + # xrh + if self.mask_decoder_matting_token: + self.sam_mask_decoder = MaskDecoderMattingToken( + # xrh + upscaled_embedding_res_add=self.upscaled_embedding_res_add, + matting_logits_res_add=self.matting_logits_res_add, + + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + else: + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + # Only hard possible with gt + assert not self.teacher_force_obj_scores_for_mem + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].cuda(non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + (abs(frame_idx - t), out["obj_ptr"]) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid emtpy memory input to tranformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + return maskmem_features, maskmem_pos_enc + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + _, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. + """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/sam2/modeling/sam2_utils.py b/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9705963efc57d74b7d1bff31692d7d293a46ad --- /dev/null +++ b/sam2/modeling/sam2_utils.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. + """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..14843fec0f791be2572d69dd8020d10e0b0669f2 --- /dev/null +++ b/sam2/sam2_image_predictor.py @@ -0,0 +1,517 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base + +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + ) -> None: + """ + Uses SAM-2 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam-2): The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to + the maximum area of fill_hole_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model) + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tupele of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." + ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False + + def predict_batch_boxes_and_features( + self, + boxes, + features, + multimask_output: bool = False, + return_logits: bool = True, + wo_matting_token: bool = False, + ): + + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + concat_points = (box_coords, box_labels) + + image_pe=self.model.sam_prompt_encoder.get_dense_pe() + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=None, + ) + + batch_size = boxes.shape[0] + cat_sam2_logits = [] + cat_matting_logits = [] + for bs_idx in range(batch_size): + + high_res_features = [ + feat_level[bs_idx].unsqueeze(0) + for feat_level in features["high_res_feats"] + ] + + sam2_logits, matting_logits = self.model.sam_mask_decoder( + image_embeddings=features["image_embed"][bs_idx].unsqueeze(0), + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings[bs_idx].unsqueeze(0), + dense_prompt_embeddings=dense_embeddings[bs_idx].unsqueeze(0), + multimask_output=multimask_output, + repeat_image=False, + high_res_features=high_res_features, + wo_matting_token=wo_matting_token, + ) + cat_sam2_logits.append(sam2_logits) + if not wo_matting_token: + cat_matting_logits.append(matting_logits) + + sam2_logits = torch.concatenate(cat_sam2_logits, dim=0) + if not wo_matting_token: + matting_logits = torch.concatenate(cat_matting_logits, dim=0) + return sam2_logits, matting_logits + else: + return sam2_logits \ No newline at end of file diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a6bdf4b5742c3214e655657abd4b6bb9b7ede2 --- /dev/null +++ b/sam2/sam2_video_predictor.py @@ -0,0 +1,957 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize a inference state.""" + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = torch.device("cuda") + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = torch.device("cuda") + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return cls(sam_model) + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def add_new_points(self, *args, **kwargs): + """Deprecated method. Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5).float() + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outptus + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temprary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. + if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + } + return compact_current_out, pred_masks_gpu + + def _run_memory_encoder( + self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/sam2/utils/__init__.py b/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2/utils/__pycache__/__init__.cpython-38.pyc b/sam2/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b21e2ff1656de771fe65e29da0d2acf066754878 Binary files /dev/null and b/sam2/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/sam2/utils/__pycache__/misc.cpython-38.pyc b/sam2/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ce20911a73e11523e49706ef94ea22a8b3767ba Binary files /dev/null and b/sam2/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/sam2/utils/__pycache__/transforms.cpython-38.pyc b/sam2/utils/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eec9b72dae7c7cb5d30d8d3bf329af70df64c79 Binary files /dev/null and b/sam2/utils/__pycache__/transforms.cpython-38.pyc differ diff --git a/sam2/utils/amg.py b/sam2/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e --- /dev/null +++ b/sam2/utils/amg.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + +# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.float().detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..df97b4a8e96f968d8993473344bf60eb8fadfd65 --- /dev/null +++ b/sam2/utils/misc.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] boxes, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. + """ + + def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self._images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.cuda(non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError("Only JPEG frames are supported at this moment") + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, image_size, offload_video_to_cpu, img_mean, img_std + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. " + "Consider building SAM 2 with CUDA extension to enable post-processing (see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..995baf989da5a8e4927c87b1bbb0777067b673cc --- /dev/null +++ b/sam2/utils/transforms.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. " + "Consider building SAM 2 with CUDA extension to enable post-processing (see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/sam2_configs/__init__.py b/sam2_configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/sam2_configs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/sam2_configs/__pycache__/__init__.cpython-38.pyc b/sam2_configs/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e29f3cd0b6e2646b4f089ce4b7273f0a93ebd3e0 Binary files /dev/null and b/sam2_configs/__pycache__/__init__.cpython-38.pyc differ diff --git a/sam2_configs/sam2_hiera_b+.yaml b/sam2_configs/sam2_hiera_b+.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58f3eb81554018e873f8515ecb98e36d16ac29e4 --- /dev/null +++ b/sam2_configs/sam2_hiera_b+.yaml @@ -0,0 +1,113 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 112 + num_heads: 2 + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [896, 448, 224, 112] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2_configs/sam2_hiera_l.yaml b/sam2_configs/sam2_hiera_l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..918667f50c3e1ad2dcf77c0c14cb4dd114cfd080 --- /dev/null +++ b/sam2_configs/sam2_hiera_l.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 144 + num_heads: 2 + stages: [2, 6, 36, 4] + global_att_blocks: [23, 33, 43] + window_pos_embed_bkg_spatial_size: [7, 7] + window_spec: [8, 4, 16, 8] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [1152, 576, 288, 144] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2_configs/sam2_hiera_s.yaml b/sam2_configs/sam2_hiera_s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2 --- /dev/null +++ b/sam2_configs/sam2_hiera_s.yaml @@ -0,0 +1,116 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 11, 2] + global_att_blocks: [7, 10, 13] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + compile_image_encoder: False diff --git a/sam2_configs/sam2_hiera_t.yaml b/sam2_configs/sam2_hiera_t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a62c903aaa5f80828077c6e06a59626926570ed6 --- /dev/null +++ b/sam2_configs/sam2_hiera_t.yaml @@ -0,0 +1,118 @@ +# @package _global_ + +# Model +model: + _target_: sam2.modeling.sam2_base.SAM2Base + image_encoder: + _target_: sam2.modeling.backbones.image_encoder.ImageEncoder + scalp: 1 + trunk: + _target_: sam2.modeling.backbones.hieradet.Hiera + embed_dim: 96 + num_heads: 1 + stages: [1, 2, 7, 2] + global_att_blocks: [5, 7, 9] + window_pos_embed_bkg_spatial_size: [7, 7] + neck: + _target_: sam2.modeling.backbones.image_encoder.FpnNeck + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 256 + normalize: true + scale: null + temperature: 10000 + d_model: 256 + backbone_channel_list: [768, 384, 192, 96] + fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features + fpn_interp_model: nearest + + memory_attention: + _target_: sam2.modeling.memory_attention.MemoryAttention + d_model: 256 + pos_enc_at_input: true + layer: + _target_: sam2.modeling.memory_attention.MemoryAttentionLayer + activation: relu + dim_feedforward: 2048 + dropout: 0.1 + pos_enc_at_attn: false + self_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + d_model: 256 + pos_enc_at_cross_attn_keys: true + pos_enc_at_cross_attn_queries: false + cross_attention: + _target_: sam2.modeling.sam.transformer.RoPEAttention + rope_theta: 10000.0 + feat_sizes: [32, 32] + rope_k_repeat: True + embedding_dim: 256 + num_heads: 1 + downsample_rate: 1 + dropout: 0.1 + kv_in_dim: 64 + num_layers: 4 + + memory_encoder: + _target_: sam2.modeling.memory_encoder.MemoryEncoder + out_dim: 64 + position_encoding: + _target_: sam2.modeling.position_encoding.PositionEmbeddingSine + num_pos_feats: 64 + normalize: true + scale: null + temperature: 10000 + mask_downsampler: + _target_: sam2.modeling.memory_encoder.MaskDownSampler + kernel_size: 3 + stride: 2 + padding: 1 + fuser: + _target_: sam2.modeling.memory_encoder.Fuser + layer: + _target_: sam2.modeling.memory_encoder.CXBlock + dim: 256 + kernel_size: 7 + padding: 3 + layer_scale_init_value: 1e-6 + use_dwconv: True # depth-wise convs + num_layers: 2 + + num_maskmem: 7 + image_size: 1024 + # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask + # SAM decoder + sigmoid_scale_for_mem_enc: 20.0 + sigmoid_bias_for_mem_enc: -10.0 + use_mask_input_as_output_without_sam: true + # Memory + directly_add_no_mem_embed: true + # use high-resolution feature map in the SAM mask decoder + use_high_res_features_in_sam: true + # output 3 masks on the first click on initial conditioning frames + multimask_output_in_sam: true + # SAM heads + iou_prediction_use_sigmoid: True + # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder: true + add_tpos_enc_to_obj_ptrs: false + only_obj_ptrs_in_the_past_for_eval: true + # object occlusion prediction + pred_obj_scores: true + pred_obj_scores_mlp: true + fixed_no_obj_ptr: true + # multimask tracking settings + multimask_output_for_tracking: true + use_multimask_token_for_obj_ptr: true + multimask_min_pt_num: 0 + multimask_max_pt_num: 1 + use_mlp_for_obj_ptr_proj: true + # Compilation flag + # HieraT does not currently support compilation, should always be set to False + compile_image_encoder: False