hanzi_YOLOV11 / load_model.py
AISkywalker's picture
Upload 101 files
3c3b0ed verified
from ultralytics import YOLO
import torch
import os
from torch import nn
def load_model(model_path: str, device: str = None):#这是返回yolo对象的模型
try:
# 检查模型路径是否存在
if not os.path.isfile(model_path):
raise FileNotFoundError(f"模型文件错误,请修改加载路径: {model_path}")
# 自动选择设备为GPU
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 验证设备是否可用
if device == 'cuda' and not torch.cuda.is_available():
print("警告: CUDA不可用,将使用CPU")
device = 'cpu'
# 加载模型
model = YOLO(model_path).to(device)
print(f"成功加载模型到 {device.upper()}")
return model
except Exception as e:
raise RuntimeError(f"加载模型失败: {str(e)}")
# --------------------------------------------------
# load_pytorch_module (返回 torch.nn.Module)
# --------------------------------------------------
def load_pytorch_module(model_path: str, device: str = None) -> nn.Module:
"""
加载 Ultralytics YOLO 模型文件,并返回底层的 PyTorch nn.Module。
这个函数专门用于需要直接获取 torch.nn.Module 实例的场景;例如,检查或特定集成。
Returns:
torch.nn.Module: 底层的 PyTorch 模型实例。
"""
try:
# 检查模型路径是否存在
if not os.path.isfile(model_path):
raise FileNotFoundError(f"模型文件错误,请修改加载路径: {model_path}")
# 自动选择设备为GPU
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 验证设备是否可用
if device == 'cuda' and not torch.cuda.is_available():
print("警告: CUDA不可用,将使用CPU")
device = 'cpu'
# 1. 先加载 YOLO 对象
print(f"使用 YOLO 加载器加载: {model_path}")
yolo_wrapper = YOLO(model_path)
print("YOLO 加载器加载成功。")
# 2. 从 YOLO 对象中提取底层的 PyTorch 模型
# 通常,这个模型存储在 YOLO 对象的 .model 属性中
print("正在提取底层的 torch.nn.Module...")
pytorch_model = yolo_wrapper.model
if not isinstance(pytorch_model, nn.Module):
# 做个健壮性检查,以防未来 Ultralytics 内部结构改变
raise TypeError(f"从YOLO对象提取的 '.model' 属性不是 torch.nn.Module 的实例,实际类型为 {type(pytorch_model)}")
print(f"成功提取 PyTorch 模型,类型: {type(pytorch_model).__name__}")
# 3. 将提取出的 PyTorch 模型移动到指定设备
print(f"正在将 PyTorch 模型移动到 {device.upper()}...")
pytorch_model.to(device)
print(f"PyTorch 模型成功移动到 {device.upper()}")
# 4. 返回这个底层的 PyTorch 模型
return pytorch_model
except Exception as e:
raise RuntimeError(f"加载底层 PyTorch 模型失败: {str(e)}")
if __name__ == "__main__":
model_path = "best_model.pt"
device = "cpu"
model = load_pytorch_module(model_path, device)
model_yolo=load_model(model_path, device)
print('原始的torch模型:',model)
print('加载yolo模型:',model_yolo)