u2net-mvtec-loco-segmentation / download_from_hf.py
zhiqing0205
Add complete U2Net project with HuggingFace preparation
ece7754
#!/usr/bin/env python3
"""
HuggingFace model download utility for U2NET MVTec LOCO segmentation
"""
import os
from huggingface_hub import hf_hub_download, snapshot_download
from pathlib import Path
def download_u2net_model(model_path="./saved_models/u2net/u2net.pth",
                         repo_id="zhiqing0205/u2net-mvtec-loco-segmentation",
                         force_download=False,
                         repo_filename="saved_models/u2net/u2net.pth"):
    """
    Download the U2NET model weights from the HuggingFace Hub to ``model_path``.

    Args:
        model_path: Local path where the model file should end up.
        repo_id: HuggingFace repository ID.
        force_download: Re-download even if ``model_path`` already exists.
        repo_filename: Path of the weights file inside the repository.

    Returns:
        The local path of the model file on success. Returns None if the
        download failed; a manual-download hint is printed in that case.
    """
    import shutil

    # Skip the network entirely when the file is already where it belongs.
    if os.path.exists(model_path) and not force_download:
        print(f"Model already exists at {model_path}")
        return model_path

    print(f"Downloading U2NET model from HuggingFace: {repo_id}")
    try:
        # os.path.dirname() is "" for a bare filename, and os.makedirs("")
        # raises, so only create a parent directory when there is one.
        parent_dir = os.path.dirname(model_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # hf_hub_download(local_dir=D) always writes to D/<repo_filename>.
        # When the requested path is exactly ./<repo_filename>, download in
        # place. Otherwise fetch into the HF cache and copy the file to
        # model_path, so it lands where the caller asked.
        in_place = os.path.abspath(model_path) == os.path.abspath(repo_filename)
        if in_place:
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=repo_filename,
                local_dir=".",
                local_dir_use_symlinks=False,
                force_download=force_download,
            )
        else:
            cached_path = hf_hub_download(
                repo_id=repo_id,
                filename=repo_filename,
                force_download=force_download,
            )
            # Copy via a temporary name. An interrupted copy must not leave
            # a truncated file that the existence check would later accept.
            partial_path = f"{model_path}.part"
            shutil.copyfile(cached_path, partial_path)
            os.replace(partial_path, model_path)
            downloaded_path = model_path

        print(f"Model downloaded successfully to: {downloaded_path}")
        return downloaded_path
    except Exception as e:
        # Best-effort CLI helper: report the failure and point to the
        # manual download instead of crashing.
        print(f"Error downloading model: {e}")
        print("Please download manually from:")
        print("https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view")
        return None
def download_complete_repo(local_dir="./u2net-mvtec-loco",
repo_id="zhiqing0205/u2net-mvtec-loco-segmentation"):
"""
Download complete repository from HuggingFace Hub
Args:
local_dir: Local directory to save the repo
repo_id: HuggingFace repository ID
"""
print(f"Downloading complete repository: {repo_id}")
try:
# Download entire repository
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
local_dir_use_symlinks=False
)
print(f"Repository downloaded successfully to: {local_dir}")
return local_dir
except Exception as e:
print(f"Error downloading repository: {e}")
return None
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Download U2NET model or complete repo from HuggingFace")
parser.add_argument("--model-only", action="store_true",
help="Download only the model file")
parser.add_argument("--complete-repo", action="store_true",
help="Download the complete repository")
parser.add_argument("--repo-id", type=str,
default="zhiqing0205/u2net-mvtec-loco-segmentation",
help="HuggingFace repository ID")
parser.add_argument("--force", action="store_true",
help="Force download even if files exist")
args = parser.parse_args()
if args.complete_repo:
download_complete_repo(repo_id=args.repo_id)
else:
download_u2net_model(repo_id=args.repo_id, force_download=args.force)