|
from ultralytics import YOLO
|
|
import torch
|
|
import os
|
|
from torch import nn
|
|
def load_model(model_path: str, device: str = None):
|
|
try:
|
|
|
|
if not os.path.isfile(model_path):
|
|
raise FileNotFoundError(f"模型文件错误,请修改加载路径: {model_path}")
|
|
|
|
|
|
if device is None:
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
if device == 'cuda' and not torch.cuda.is_available():
|
|
print("警告: CUDA不可用,将使用CPU")
|
|
device = 'cpu'
|
|
|
|
|
|
model = YOLO(model_path).to(device)
|
|
print(f"成功加载模型到 {device.upper()}")
|
|
return model
|
|
|
|
except Exception as e:
|
|
raise RuntimeError(f"加载模型失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_pytorch_module(model_path: str, device: str = None) -> nn.Module:
|
|
"""
|
|
加载 Ultralytics YOLO 模型文件,并返回底层的 PyTorch nn.Module。
|
|
这个函数专门用于需要直接获取 torch.nn.Module 实例的场景;例如,检查或特定集成。
|
|
|
|
Returns:
|
|
torch.nn.Module: 底层的 PyTorch 模型实例。
|
|
"""
|
|
try:
|
|
|
|
if not os.path.isfile(model_path):
|
|
raise FileNotFoundError(f"模型文件错误,请修改加载路径: {model_path}")
|
|
|
|
|
|
if device is None:
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
if device == 'cuda' and not torch.cuda.is_available():
|
|
print("警告: CUDA不可用,将使用CPU")
|
|
device = 'cpu'
|
|
|
|
|
|
print(f"使用 YOLO 加载器加载: {model_path}")
|
|
yolo_wrapper = YOLO(model_path)
|
|
print("YOLO 加载器加载成功。")
|
|
|
|
|
|
|
|
print("正在提取底层的 torch.nn.Module...")
|
|
pytorch_model = yolo_wrapper.model
|
|
if not isinstance(pytorch_model, nn.Module):
|
|
|
|
raise TypeError(f"从YOLO对象提取的 '.model' 属性不是 torch.nn.Module 的实例,实际类型为 {type(pytorch_model)}")
|
|
print(f"成功提取 PyTorch 模型,类型: {type(pytorch_model).__name__}")
|
|
|
|
|
|
print(f"正在将 PyTorch 模型移动到 {device.upper()}...")
|
|
pytorch_model.to(device)
|
|
print(f"PyTorch 模型成功移动到 {device.upper()}")
|
|
|
|
|
|
return pytorch_model
|
|
|
|
except Exception as e:
|
|
raise RuntimeError(f"加载底层 PyTorch 模型失败: {str(e)}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
model_path = "best_model.pt"
|
|
device = "cpu"
|
|
model = load_pytorch_module(model_path, device)
|
|
model_yolo=load_model(model_path, device)
|
|
print('原始的torch模型:',model)
|
|
|
|
print('加载yolo模型:',model_yolo)
|
|
|