chore: 加载时将图像缩放到224*224, resnet禁用crop
This commit is contained in:
@@ -14,9 +14,9 @@ freeze_backbone: true # 冻结ResNet参数,只训练后面的pool和out层(
|
||||
# ====================
|
||||
# 输入配置
|
||||
# ====================
|
||||
input_shape: [3, 96, 96] # 输入图像形状 (C, H, W)
|
||||
crop_shape: [84, 84] # 裁剪后的图像形状 (H, W)
|
||||
crop_is_random: true # 训练时使用随机裁剪,评估时使用中心裁剪
|
||||
input_shape: [3, 224, 224] # 输入图像形状 (C, H, W) - ImageNet标准尺寸
|
||||
crop_shape: null # 裁剪后的图像形状 (H, W) - 设为null禁用裁剪
|
||||
crop_is_random: true # 训练时使用随机裁剪,评估时使用中心裁剪(crop_shape=null时无效)
|
||||
|
||||
# ====================
|
||||
# 归一化和特征提取
|
||||
|
||||
@@ -86,6 +86,9 @@ class SimpleRobotDataset(Dataset):
|
||||
h5_path = f'observations/images/{cam_name}'
|
||||
if h5_path in f:
|
||||
img = f[h5_path][meta["frame_idx"]]
|
||||
# Resize图像到224x224(减少内存和I/O负担)
|
||||
import cv2
|
||||
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
|
||||
# 转换为float并归一化到 [0, 1]
|
||||
img = torch.from_numpy(img).float() / 255.0
|
||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||
|
||||
Reference in New Issue
Block a user