import gradio as gr
from gradio_image_prompter import ImagePrompter
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
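
# Supported SEMat variants, each mapped to the local path of its pretrained
# checkpoint after being downloaded from the Hugging Face Hub at startup.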
model_choice = {
    'SAM': None,
    'HQ-SAM': None,
    'SAM2': None,
}

for model_type in model_choice.keys():
    model_choice[model_type] = hf_hub_download(
        repo_id="XiaRho/SEMat",
        filename=f"SEMat_{model_type}.pth",
        repo_type="model",
    )
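

# Build the requested SEMat variant from its LazyConfig, initialise LoRA
# layers when the config uses them, and restore the downloaded checkpoint
# with DetectionCheckpointer.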
def load_model(model_type='SAM2'):
    assert model_type in model_choice.keys()
    config_path = './configs/SEMat_{}.py'.format(model_type)
    cfg = LazyConfig.load(config_path)

    # Do not load a separate SAM checkpoint here; the full SEMat checkpoint
    # is restored below.
    if hasattr(cfg.model.sam_model, 'ckpt_path'):
        cfg.model.sam_model.ckpt_path = None
    else:
        cfg.model.sam_model.checkpoint = None

    model = instantiate(cfg.model)
    if model.lora_rank is not None:
        model.init_lora()
    model.to(DEVICE)
    DetectionCheckpointer(model).load(model_choice[model_type])
    model.eval()
    return model, model_type
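

# Convert an ImagePrompter prompt (image + one drawn box) into model inputs:
# resize the longer side to 1024, pad to a 1024 x 1024 canvas, normalise to
# CHW float32 in [0, 1], and scale the box into the resized coordinate frame.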
def transform_image_bbox(prompts):
    if len(prompts["points"]) != 1:
        raise gr.Error("Please input only one BBox.", duration=5)
    [[x1, y1, idx_3, x2, y2, idx_6]] = prompts["points"]
    if idx_3 != 2 or idx_6 != 3:
        raise gr.Error("Please input a BBox instead of a point.", duration=5)

    img = prompts["image"]
    ori_H, ori_W, _ = img.shape

    # Resize so that the longer side becomes 1024.
    scale = 1024 * 1.0 / max(ori_H, ori_W)
    new_H, new_W = ori_H * scale, ori_W * scale
    new_W = int(new_W + 0.5)
    new_H = int(new_H + 0.5)

    # Pad the resized image onto a fixed 1024 x 1024 canvas.
    img = cv2.resize(img, (new_W, new_H), interpolation=cv2.INTER_LINEAR)
    padding = np.zeros([1024, 1024, 3], dtype=img.dtype)
    padding[:new_H, :new_W, :] = img
    img = padding

    # HWC uint8 -> CHW float32 in [0, 1].
    img = img.transpose((2, 0, 1)).astype(np.float32) / 255.0

    # Scale the box into the resized frame and clamp it to the canvas.
    [[x1, y1, _, x2, y2, _]] = prompts["points"]
    x1, y1, x2, y2 = int(x1 * scale + 0.5), int(y1 * scale + 0.5), int(x2 * scale + 0.5), int(y2 * scale + 0.5)
    bbox = np.clip(np.array([[x1, y1, x2, y2]]) * 1.0, 0, 1023.0)

    return img, bbox, (ori_H, ori_W), (new_H, new_W)
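

# Entry point: load the default SAM2 variant and serve the Gradio demo;
# inference_image reloads the model whenever the selected variant changes.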
if __name__ == '__main__':

    model, model_type = load_model()

    def inference_image(prompts, input_model_type):
        global model_type
        global model

        # Swap in a different SEMat variant if the user changed the dropdown.
        if input_model_type != model_type:
            gr.Info('Loading the {} version of SEMat.'.format(input_model_type), duration=5)
            _model, _ = load_model(input_model_type)
            model_type = input_model_type
            model = _model

        image, bbox, ori_H_W, pad_H_W = transform_image_bbox(prompts)
        input_data = {
            'image': torch.from_numpy(image)[None].to(model.device),
            'bbox': torch.from_numpy(bbox)[None].to(model.device),
        }

        with torch.no_grad():
            inputs = model.preprocess_inputs(input_data)
            images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition']

            if model.backbone_condition:
                condition_proj = model.condition_embedding(condition)
            elif model.backbone_bbox_prompt is not None or model.bbox_prompt_all_block is not None:
                condition_proj = bbox
            else:
                condition_proj = None

            low_res_masks, pred_alphas, pred_trimap, sam_hq_matting_token = model.forward_samhq_and_matting_decoder(images, bbox, condition_proj)

        # Crop away the padding and turn the predicted alpha matte into a
        # 3-channel uint8 image for display.
        output_alpha = np.uint8(pred_alphas[0, 0][:pad_H_W[0], :pad_H_W[1], None].repeat(1, 1, 3).cpu().numpy() * 255)

        return output_alpha
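
    # Two-column layout: the box-prompt image input and the alpha output on
    # top, the variant selector and the run button below.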
    with gr.Blocks() as demo:

        with gr.Row():
            with gr.Column(scale=45):
                img_in = ImagePrompter(type='numpy', show_label=False, label="query image")

            with gr.Column(scale=45):
                img_out = gr.Image(type='pil', label="output")

        with gr.Row():
            with gr.Column(scale=45):
                input_model_type = gr.Dropdown(list(model_choice.keys()), value='SAM2', label="Trained SEMat Version")

            with gr.Column(scale=45):
                bt = gr.Button()

        bt.click(inference_image, inputs=[img_in, input_model_type], outputs=[img_out])

    demo.launch()