|
|
|
""" |
|
HuggingFace model download utility for U2NET MVTec LOCO segmentation |
|
""" |
|
|
|
import os
import shutil
import sys
from pathlib import Path

from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
def download_u2net_model(model_path="./saved_models/u2net/u2net.pth",
                         repo_id="zhiqing0205/u2net-mvtec-loco-segmentation",
                         force_download=False,
                         repo_filename="saved_models/u2net/u2net.pth"):
    """
    Download the U2NET checkpoint from the HuggingFace Hub to ``model_path``.

    The file is fetched through the HuggingFace cache and then copied to
    ``model_path``, so a custom ``model_path`` is actually honoured.

    Args:
        model_path: Local path the checkpoint is written to.
        repo_id: HuggingFace repository ID.
        force_download: Re-download even if ``model_path`` already exists.
            This flag is also forwarded to ``hf_hub_download``.
        repo_filename: Path of the checkpoint file inside the repository.

    Returns:
        ``model_path`` if the file already existed or was downloaded
        successfully; ``None`` if the download or copy failed.
    """
    # Only a regular file counts as an existing model. A directory with the
    # same name must not short-circuit the download.
    if os.path.isfile(model_path) and not force_download:
        print(f"Model already exists at {model_path}")
        return model_path

    print(f"Downloading U2NET model from HuggingFace: {repo_id}")

    tmp_path = f"{model_path}.part"
    try:
        parent = os.path.dirname(model_path)
        # dirname() is "" for a bare filename, and os.makedirs("") would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename=repo_filename,
            force_download=force_download,
        )

        # Copy under a temporary name, then rename in one step. An interrupted
        # copy then never leaves a truncated file that the existence check
        # above would later accept as a valid model.
        shutil.copyfile(cached_path, tmp_path)
        os.replace(tmp_path, model_path)
    except Exception as e:
        # Best-effort download: clean up any partial copy and report the
        # failure without raising.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # no partial file was created; nothing to clean up
        print(f"Error downloading model: {e}", file=sys.stderr)
        print("Please download manually from:", file=sys.stderr)
        print("https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view",
              file=sys.stderr)
        return None

    print(f"Model downloaded successfully to: {model_path}")
    return model_path
|
|
|
def download_complete_repo(local_dir="./u2net-mvtec-loco",
                           repo_id="zhiqing0205/u2net-mvtec-loco-segmentation",
                           force_download=False):
    """
    Download a full snapshot of a HuggingFace Hub repository.

    Args:
        local_dir: Local directory the repository files are written to.
        repo_id: HuggingFace repository ID.
        force_download: Re-download files even if they are already present.
            This flag is forwarded to ``snapshot_download``.

    Returns:
        ``local_dir`` on success, ``None`` if the download failed.
    """
    print(f"Downloading complete repository: {repo_id}")

    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            force_download=force_download,
        )
    except Exception as e:
        # Best-effort download: report the failure and return None.
        print(f"Error downloading repository: {e}", file=sys.stderr)
        return None

    print(f"Repository downloaded successfully to: {local_dir}")
    return local_dir


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Download U2NET model or complete repo from HuggingFace")

    # The two modes are alternatives. The model-only download is the default
    # when neither flag is given.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--model-only", action="store_true",
                      help="Download only the model file")
    mode.add_argument("--complete-repo", action="store_true",
                      help="Download the complete repository")

    parser.add_argument("--repo-id", type=str,
                        default="zhiqing0205/u2net-mvtec-loco-segmentation",
                        help="HuggingFace repository ID")
    parser.add_argument("--force", action="store_true",
                        help="Force download even if files exist")

    args = parser.parse_args()

    if args.complete_repo:
        result = download_complete_repo(repo_id=args.repo_id,
                                        force_download=args.force)
    else:
        result = download_u2net_model(repo_id=args.repo_id,
                                      force_download=args.force)

    # Both helpers return None on failure. Propagate that to the shell so
    # scripts and CI can detect a failed download.
    sys.exit(0 if result is not None else 1)