更改默认参数
This commit is contained in:
@@ -25,8 +25,8 @@ class VLAAgent(nn.Module):
|
||||
inference_steps=10, # DDIM 推理步数
|
||||
num_cams=3, # 视觉输入的摄像头数量
|
||||
dataset_stats=None, # 数据集统计信息,用于归一化
|
||||
normalization_type='gaussian', # 归一化类型: 'gaussian' 或 'min_max'
|
||||
num_action_steps=1, # 每次推理实际执行多少步动作
|
||||
normalization_type='min_max', # 归一化类型: 'gaussian' 或 'min_max'
|
||||
num_action_steps=8, # 每次推理实际执行多少步动作
|
||||
):
|
||||
super().__init__()
|
||||
# 保存参数
|
||||
|
||||
@@ -9,19 +9,19 @@ defaults:
|
||||
# ====================
|
||||
train:
|
||||
# 基础训练参数
|
||||
batch_size: 8 # 批次大小
|
||||
batch_size: 32 # 批次大小
|
||||
lr: 1e-4 # 学习率
|
||||
max_steps: 100000 # 最大训练步数
|
||||
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||
|
||||
# 数据加载
|
||||
num_workers: 8 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
||||
num_workers: 40 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
||||
val_split: 0.1 # 验证集比例
|
||||
seed: 42 # 随机种子(用于数据划分)
|
||||
|
||||
# 日志和检查点
|
||||
log_freq: 100 # 日志记录频率(步数)
|
||||
save_freq: 5000 # 保存检查点频率(步数)
|
||||
save_freq: 2000 # 保存检查点频率(步数)
|
||||
|
||||
# 学习率调度器(带预热)
|
||||
warmup_steps: 500 # 预热步数
|
||||
|
||||
@@ -21,7 +21,7 @@ class NormalizationModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
stats: Optional[Dict] = None,
|
||||
normalization_type: Literal['gaussian', 'min_max'] = 'gaussian'
|
||||
normalization_type: Literal['gaussian', 'min_max'] = 'min_max'
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user