chore: 删除多余脚本
This commit is contained in:
@@ -1,75 +0,0 @@
|
||||
# 图像预处理
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from typing import Union, List
|
||||
|
||||
class VLAImageProcessor:
    """Image preprocessor for VLA models built on ViT encoders (SigLIP/CLIP).

    Pipeline applied by ``__call__``:
        1. Convert the input (HWC numpy array, CHW tensor or PIL image) to PIL.
        2. Optionally apply colour jitter (training-time augmentation).
        3. Resize to ``resolution x resolution`` with bicubic interpolation.
        4. Convert to a CHW float tensor and normalize with ``mean``/``std``
           (SigLIP convention: mean=0.5, std=0.5).
    """

    def __init__(
        self,
        resolution: int = 384,
        mean: List[float] = [0.5, 0.5, 0.5],
        std: List[float] = [0.5, 0.5, 0.5],
        enable_augmentation: bool = True,
        aug_strength: float = 0.1,  # augmentation strength; 0.1~0.2 is fairly safe
    ):
        """Build the transform stages.

        Args:
            resolution: Side length of the square output image.
            mean: Per-channel mean used for normalization.
            std: Per-channel standard deviation used for normalization.
            enable_augmentation: Apply colour jitter in ``__call__`` when True.
            aug_strength: Jitter magnitude for brightness/contrast/saturation;
                hue uses half of this value.
        """
        self.resolution = resolution
        self.enable_augmentation = enable_augmentation

        # --- 1. Base processing (shared by every mode) ---
        # The stages are kept separate because augmentation is usually
        # faster when done while the image is still a PIL image.
        self.to_pil = T.ToPILImage()
        self.resize = T.Resize(
            (resolution, resolution),
            interpolation=T.InterpolationMode.BICUBIC,
            antialias=True,
        )
        self.to_tensor = T.ToTensor()
        # Copy mean/std so the shared default lists in the signature are never
        # aliased by (and mutated through) the Normalize transform.
        self.normalize = T.Normalize(mean=list(mean), std=list(std))

        # --- 2. Data augmentation (training only) ---
        # Robot learning usually avoids RandomCrop (it discards absolute
        # position information), so only colour augmentation is used.
        if enable_augmentation:
            self.aug = T.ColorJitter(
                brightness=aug_strength,
                contrast=aug_strength,
                saturation=aug_strength,
                hue=aug_strength / 2,
            )
        else:
            self.aug = torch.nn.Identity()

    def __call__(self, img: Union[np.ndarray, Image.Image, torch.Tensor]) -> torch.Tensor:
        """Preprocess a single image.

        Args:
            img: (H, W, C) uint8 numpy array (e.g. read from HDF5), a PIL
                image, or a tensor. Tensors are assumed to be (C, H, W),
                either uint8 or float in [0, 1], as expected by
                ``torchvision.transforms.ToPILImage``.

        Returns:
            (C, H, W) float32 tensor, normalized with the configured mean/std.

        Raises:
            TypeError: If ``img`` is not a numpy array, PIL image or tensor.
        """
        # 1. Bring every supported input to a PIL image so that jitter and
        #    resize behave identically regardless of the source type.
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif isinstance(img, torch.Tensor):
            # Previously tensors fell through untouched and crashed in
            # ToTensor(), which only accepts PIL images / ndarrays.
            img = self.to_pil(img.detach().cpu())
        elif not isinstance(img, Image.Image):
            raise TypeError(
                f"Unsupported image type: {type(img).__name__}; "
                "expected numpy.ndarray, PIL.Image.Image or torch.Tensor"
            )

        # 2. Data augmentation (if enabled)
        if self.enable_augmentation:
            img = self.aug(img)

        # 3. Resize
        img = self.resize(img)

        # 4. To tensor & normalize; ToTensor maps [0, 255] -> [0.0, 1.0]
        tensor = self.to_tensor(img)
        tensor = self.normalize(tensor)

        return tensor

    def __repr__(self):
        """Return a short summary of the resolution and augmentation flag."""
        return f"VLAImageProcessor(res={self.resolution}, aug={self.enable_augmentation})"
|
||||
@@ -1 +0,0 @@
|
||||
# 文本 Tokenizer 包装
|
||||
Reference in New Issue
Block a user