|
import os
|
|
import json
|
|
import shutil
|
|
import random
|
|
from glob import glob
|
|
from sklearn.model_selection import train_test_split
|
|
from tqdm import tqdm
|
|
|
|
|
|
data_root = r'D:\不会编程\CVPR\class_project\Data\data'
|
|
char_dict_path = r'D:\不会编程\CVPR\class_project\char_dict.json'
|
|
output_dir = r'D:\不会编程\CVPR\class_project\yolo_hanzi_dataset'
|
|
|
|
|
|
val_size = 0.15
|
|
test_size = 0.15
|
|
|
|
|
|
|
|
|
|
bbox_annotation = "0.5 0.5 1.0 1.0"
|
|
|
|
|
|
random_seed = 42
|
|
random.seed(random_seed)
|
|
|
|
|
|
print("加载 Char_dict.json...")
|
|
try:
|
|
with open(char_dict_path, 'r', encoding='utf-8') as f:
|
|
char_dict_raw = json.load(f)
|
|
except UnicodeDecodeError:
|
|
try:
|
|
with open(char_dict_path, 'r', encoding='ansi') as f:
|
|
char_dict_raw = json.load(f)
|
|
print("警告: Char_dict.json 使用 ANSI 编码读取,建议转换为 UTF-8。")
|
|
except Exception as e:
|
|
print(f"错误: 无法读取 Char_dict.json。请确保文件存在且编码正确 (UTF-8 或 ANSI)。错误信息: {e}")
|
|
exit()
|
|
except FileNotFoundError:
|
|
print(f"错误: 找不到 Char_dict.json 文件,路径: {char_dict_path}")
|
|
exit()
|
|
except json.JSONDecodeError as e:
|
|
print(f"错误: Char_dict.json 文件格式无效。错误信息: {e}")
|
|
exit()
|
|
|
|
|
|
|
|
|
|
|
|
print(f"从 JSON 文件加载了 {len(char_dict_raw)} 个原始映射条目。")
|
|
|
|
|
|
print(f"扫描数据目录 '{data_root}' 以确定实际存在的类别...")
|
|
actual_unicode_keys = set()
|
|
try:
|
|
data_folders = [d for d in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, d))]
|
|
print(f"找到 {len(data_folders)} 个文件夹。")
|
|
for folder_name in data_folders:
|
|
try:
|
|
|
|
key = str(int(folder_name))
|
|
if key in char_dict_raw:
|
|
actual_unicode_keys.add(key)
|
|
except ValueError:
|
|
|
|
pass
|
|
except FileNotFoundError:
|
|
print(f"错误:无法访问指定的 data_root 目录: {data_root}。")
|
|
exit()
|
|
except Exception as e:
|
|
print(f"扫描数据目录时发生错误: {e}")
|
|
exit()
|
|
|
|
if not actual_unicode_keys:
|
|
print(f"错误:在数据目录 '{data_root}' 中没有找到任何与 JSON 文件中的键匹配的文件夹。")
|
|
exit()
|
|
|
|
print(f"根据文件夹名称,确定了 {len(actual_unicode_keys)} 个实际存在的有效类别键。")
|
|
|
|
|
|
|
|
sorted_actual_keys = sorted(list(actual_unicode_keys), key=lambda x: int(x))
|
|
|
|
unicode_to_classid = {}
|
|
classid_to_char = {}
|
|
num_classes = 0
|
|
|
|
print("正在创建过滤后的类别映射...")
|
|
for i, unicode_key in enumerate(sorted_actual_keys):
|
|
if unicode_key in char_dict_raw:
|
|
|
|
original_char = char_dict_raw[unicode_key]
|
|
cleaned_char = original_char.replace('\u0000', '')
|
|
|
|
|
|
if not cleaned_char:
|
|
print(f"警告:键 '{unicode_key}' 对应的字符为空,将跳过此类别。")
|
|
continue
|
|
|
|
class_id = num_classes
|
|
unicode_to_classid[unicode_key] = class_id
|
|
classid_to_char[class_id] = cleaned_char
|
|
num_classes += 1
|
|
else:
|
|
|
|
print(f"内部错误:键 '{unicode_key}' 未在 char_dict_raw 中找到,尽管它在 actual_unicode_keys 中。")
|
|
|
|
if num_classes == 0:
|
|
print("错误:过滤后没有有效的类别。请检查文件夹名称、JSON内容和脚本逻辑。")
|
|
exit()
|
|
|
|
print(f"最终确定使用 {num_classes} 个有效类别进行处理。")
|
|
|
|
print("有效类别映射 (部分示例):")
|
|
|
|
example_keys = list(unicode_to_classid.keys())
|
|
for i in range(min(5, num_classes)):
|
|
key = example_keys[i]
|
|
class_id = unicode_to_classid[key]
|
|
char = classid_to_char[class_id]
|
|
print(f" Class ID: {class_id}, Key: {key}, Character: '{char}'")
|
|
if num_classes > 10:
|
|
print(" ...")
|
|
for i in range(max(5, num_classes - 5), num_classes):
|
|
key = example_keys[i]
|
|
class_id = unicode_to_classid[key]
|
|
char = classid_to_char[class_id]
|
|
print(f" Class ID: {class_id}, Key: {key}, Character: '{char}'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("扫描图像文件...")
|
|
|
|
|
|
all_image_paths = []
|
|
all_labels = []
|
|
|
|
|
|
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.gif']
|
|
|
|
|
|
print(f"正在扫描根目录: {data_root}")
|
|
|
|
|
|
found_folders = 0
|
|
processed_folders = 0
|
|
|
|
|
|
|
|
try:
|
|
root_contents = os.listdir(data_root)
|
|
print(f"在 {data_root} 中找到 {len(root_contents)} 个项目 (文件/文件夹)")
|
|
|
|
except FileNotFoundError:
|
|
print(f"错误:无法访问指定的 data_root 目录: {data_root}。请确保路径正确且程序有权限访问。")
|
|
exit()
|
|
except Exception as e:
|
|
print(f"访问目录 {data_root} 时发生未知错误: {e}")
|
|
exit()
|
|
|
|
|
|
|
|
|
|
for item_name in tqdm(root_contents, desc="扫描项目"):
|
|
item_path = os.path.join(data_root, item_name)
|
|
|
|
if os.path.isdir(item_path):
|
|
found_folders += 1
|
|
|
|
unicode_index_folder_padded = item_name
|
|
|
|
try:
|
|
|
|
|
|
unicode_index_int = int(unicode_index_folder_padded)
|
|
unicode_index_key = str(unicode_index_int)
|
|
|
|
|
|
if unicode_index_key in unicode_to_classid:
|
|
processed_folders += 1
|
|
class_id = unicode_to_classid[unicode_index_key]
|
|
image_count_in_folder = 0
|
|
for ext in image_extensions:
|
|
|
|
|
|
folder_images = glob(os.path.join(item_path, ext))
|
|
for img_path in folder_images:
|
|
all_image_paths.append(img_path)
|
|
all_labels.append(class_id)
|
|
image_count_in_folder += 1
|
|
|
|
|
|
else:
|
|
|
|
if processed_folders < 5 and found_folders <= 20:
|
|
print(f"警告: 从文件夹名称 '{unicode_index_folder_padded}' 导出的键 '{unicode_index_key}' 在 Char_dict.json 中没有对应的条目,将跳过此文件夹。")
|
|
|
|
except ValueError:
|
|
|
|
if found_folders <= 10:
|
|
print(f"信息: 跳过非预期格式的文件夹/文件: '{unicode_index_folder_padded}'")
|
|
except Exception as e:
|
|
print(f"处理文件夹 {item_name} 时发生错误: {e}")
|
|
|
|
|
|
|
|
print(f"总共扫描了 {found_folders} 个文件夹。")
|
|
print(f"其中 {processed_folders} 个文件夹的名称与 Char_dict.json 中的键成功匹配。")
|
|
|
|
|
|
if not all_image_paths:
|
|
print(f"错误: 在目录 '{data_root}' 下及其与 JSON 匹配的子文件夹中,没有找到任何支持的图像文件。")
|
|
print(f"请再次检查:")
|
|
print(f"1. 路径 '{data_root}' 是否正确?")
|
|
print(f"2. 匹配的文件夹 (如 {processed_folders} 个) 中是否真的包含 {', '.join(image_extensions)} 格式的图片文件?")
|
|
print(f"3. 文件权限是否允许读取?")
|
|
exit()
|
|
|
|
print(f"总共找到 {len(all_image_paths)} 个图像文件。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("划分数据集 (Train/Validation/Test)...")
|
|
|
|
|
|
|
|
assert len(all_image_paths) == len(all_labels)
|
|
labels_int = [int(label) for label in all_labels]
|
|
|
|
|
|
|
|
|
|
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
|
|
all_image_paths,
|
|
labels_int,
|
|
test_size=test_size,
|
|
random_state=random_seed,
|
|
stratify=labels_int
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
relative_val_size = val_size / (1.0 - test_size)
|
|
|
|
train_paths, val_paths, train_labels, val_labels = train_test_split(
|
|
train_val_paths,
|
|
train_val_labels,
|
|
test_size=relative_val_size,
|
|
random_state=random_seed,
|
|
stratify=train_val_labels
|
|
)
|
|
|
|
print(f"划分结果:")
|
|
print(f" 训练集: {len(train_paths)} 张图片")
|
|
print(f" 验证集: {len(val_paths)} 张图片")
|
|
print(f" 测试集: {len(test_paths)} 张图片")
|
|
|
|
|
|
print(f"创建 YOLO 格式数据集到 '{output_dir}'...")
|
|
|
|
|
|
|
|
sets = {
|
|
'train': (train_paths, train_labels),
|
|
'val': (val_paths, val_labels),
|
|
'test': (test_paths, test_labels)
|
|
}
|
|
|
|
if os.path.exists(output_dir):
|
|
print(f"警告: 输出目录 '{output_dir}' 已存在,将清空并重新创建。")
|
|
shutil.rmtree(output_dir)
|
|
|
|
|
|
os.makedirs(os.path.join(output_dir, 'images', 'train'), exist_ok=True)
|
|
os.makedirs(os.path.join(output_dir, 'labels', 'train'), exist_ok=True)
|
|
os.makedirs(os.path.join(output_dir, 'images', 'val'), exist_ok=True)
|
|
os.makedirs(os.path.join(output_dir, 'labels', 'val'), exist_ok=True)
|
|
os.makedirs(os.path.join(output_dir, 'images', 'test'), exist_ok=True)
|
|
os.makedirs(os.path.join(output_dir, 'labels', 'test'), exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for set_name, (paths, labels) in sets.items():
|
|
print(f"处理 {set_name} 集...")
|
|
image_dir = os.path.join(output_dir, 'images', set_name)
|
|
label_dir = os.path.join(output_dir, 'labels', set_name)
|
|
|
|
for img_path, class_id in tqdm(zip(paths, labels), total=len(paths), desc=f" 拷贝和生成标注 ({set_name})"):
|
|
|
|
|
|
base_filename = os.path.basename(img_path)
|
|
|
|
filename_no_ext = os.path.splitext(base_filename)[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parent_folder_name = os.path.basename(os.path.dirname(img_path))
|
|
unique_filename_no_ext = f"{parent_folder_name}_{filename_no_ext}"
|
|
unique_base_filename = f"{unique_filename_no_ext}{os.path.splitext(base_filename)[1]}"
|
|
|
|
|
|
dest_img_path = os.path.join(image_dir, unique_base_filename)
|
|
|
|
dest_label_path = os.path.join(label_dir, f"{unique_filename_no_ext}.txt")
|
|
|
|
|
|
try:
|
|
shutil.copy2(img_path, dest_img_path)
|
|
except Exception as e:
|
|
print(f"\n错误: 无法复制文件 {img_path} 到 {dest_img_path}. 错误: {e}")
|
|
continue
|
|
|
|
|
|
annotation_line = f"{class_id} {bbox_annotation}\n"
|
|
try:
|
|
with open(dest_label_path, 'w', encoding='utf-8') as f_label:
|
|
f_label.write(annotation_line)
|
|
except Exception as e:
|
|
print(f"\n错误: 无法写入标注文件 {dest_label_path}. 错误: {e}")
|
|
|
|
if os.path.exists(dest_img_path):
|
|
os.remove(dest_img_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("生成 dataset.yaml...")
|
|
|
|
|
|
class_names = [classid_to_char[i] for i in range(num_classes)]
|
|
|
|
|
|
|
|
|
|
yaml_content = f"""
|
|
# YOLOv11 dataset configuration file
|
|
|
|
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
|
|
path: {os.path.abspath(output_dir)} # dataset root dir (绝对路径通常更可靠)
|
|
train: images/train # train images (relative to 'path')
|
|
val: images/val # val images (relative to 'path')
|
|
test: images/test # test images (optional)
|
|
|
|
# Classes
|
|
nc: {num_classes} # number of classes
|
|
names: {json.dumps(class_names, ensure_ascii=False)} # Use json.dumps for proper list formatting and handling non-ASCII characters
|
|
|
|
"""
|
|
|
|
yaml_path = os.path.join(output_dir, 'dataset.yaml')
|
|
try:
|
|
with open(yaml_path, 'w', encoding='utf-8') as f_yaml:
|
|
f_yaml.write(yaml_content)
|
|
except Exception as e:
|
|
print(f"错误: 无法写入 dataset.yaml 文件到 {yaml_path}. 错误: {e}")
|
|
|
|
print("-" * 30)
|
|
print("数据集准备完成!")
|
|
print(f"YOLO 格式的数据集已生成在: {os.path.abspath(output_dir)}")
|
|
print(f"配置文件为: {os.path.abspath(yaml_path)}")
|
|
print("-" * 30)
|
|
print("重要提示:")
|
|
print("1. 每个图像只包含一个居中的汉字,并为其生成了覆盖整个图像的边界框。")
|
|
print("2. 检查 `dataset.yaml` 文件中的路径是否正确,特别是 `path` 字段。根据你的训练环境可能需要调整。")
|
|
print("3. 检查生成的 `labels` 文件夹中的 .txt 文件内容是否符合预期格式。")
|
|
print("4. 在开始训练前,建议随机抽查几个图片及其对应的标注文件,确保它们正确对应。")
|
|
print("5. 确保 `Char_dict.json` 中的汉字编码与你的系统或训练环境兼容。YAML 文件已使用 UTF-8 编码保存。")
|
|
print("-" * 30)
|
|
|