Merge branch 'dev' of gitlab.com:leeeezd0016-group/gouhanke-vla into dev

This commit is contained in:
gouhanke
2026-02-11 17:20:21 +08:00
3 changed files with 6 additions and 6 deletions

View File

@@ -25,8 +25,8 @@ class VLAAgent(nn.Module):
inference_steps=10, # DDIM 推理步数 inference_steps=10, # DDIM 推理步数
num_cams=3, # 视觉输入的摄像头数量 num_cams=3, # 视觉输入的摄像头数量
dataset_stats=None, # 数据集统计信息,用于归一化 dataset_stats=None, # 数据集统计信息,用于归一化
normalization_type='gaussian', # 归一化类型: 'gaussian' 或 'min_max' normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
num_action_steps=1, # 每次推理实际执行多少步动作 num_action_steps=8, # 每次推理实际执行多少步动作
): ):
super().__init__() super().__init__()
# 保存参数 # 保存参数

View File

@@ -9,19 +9,19 @@ defaults:
# ==================== # ====================
train: train:
# 基础训练参数 # 基础训练参数
batch_size: 8 # 批次大小 batch_size: 32 # 批次大小
lr: 1e-4 # 学习率 lr: 1e-4 # 学习率
max_steps: 100000 # 最大训练步数 max_steps: 100000 # 最大训练步数
device: "cuda" # 设备: "cuda" 或 "cpu" device: "cuda" # 设备: "cuda" 或 "cpu"
# 数据加载 # 数据加载
num_workers: 8 # DataLoader 工作进程数(调试时设为 0生产环境用 8 num_workers: 40 # DataLoader 工作进程数(调试时设为 0生产环境用 8
val_split: 0.1 # 验证集比例 val_split: 0.1 # 验证集比例
seed: 42 # 随机种子(用于数据划分) seed: 42 # 随机种子(用于数据划分)
# 日志和检查点 # 日志和检查点
log_freq: 100 # 日志记录频率(步数) log_freq: 100 # 日志记录频率(步数)
save_freq: 5000 # 保存检查点频率(步数) save_freq: 2000 # 保存检查点频率(步数)
# 学习率调度器(带预热) # 学习率调度器(带预热)
warmup_steps: 500 # 预热步数 warmup_steps: 500 # 预热步数

View File

@@ -21,7 +21,7 @@ class NormalizationModule(nn.Module):
def __init__( def __init__(
self, self,
stats: Optional[Dict] = None, stats: Optional[Dict] = None,
normalization_type: Literal['gaussian', 'min_max'] = 'gaussian' normalization_type: Literal['gaussian', 'min_max'] = 'min_max'
): ):
""" """
Args: Args: