Files
mamba_diffusion/train_as_mamba.sh
2026-01-21 15:14:04 +08:00

63 lines
1.5 KiB
Bash
Executable File

#!/usr/bin/env bash
# Abort on any failing command, on use of an unset variable, and on a failure
# in any stage of a pipeline.
set -o errexit -o nounset -o pipefail

# Settings for the training command launched at the end of this script.
# Each value is handed to main.py through the command-line option that
# carries the same name in lower-case, dash-separated form.

# --device / --epochs / --steps-per-epoch / --batch-size / --seq-len / --lr
DEVICE="cuda"
EPOCHS="80"
STEPS_PER_EPOCH="100"
BATCH_SIZE="256"
SEQ_LEN="20"
LR="1e-3"

# --dt-min / --dt-max / --dt-alpha
DT_MIN="1e-3"
DT_MAX="0.10"
DT_ALPHA="6.0"

# --lambda-flow / --lambda-pos / --lambda-dt
LAMBDA_FLOW="1.0"
LAMBDA_POS="0.0"
LAMBDA_DT="0.5"

# Toggles that select between the --use-*-loss and --no-use-*-loss switches.
USE_FLOW_LOSS="true"
USE_POS_LOSS="false"
USE_DT_LOSS="true"

# --val-every / --val-samples / --val-plot-samples / --val-max-steps
VAL_EVERY="200"
VAL_SAMPLES="512"
VAL_PLOT_SAMPLES="16"
VAL_MAX_STEPS="100"

# --center-min / --center-max / --center-distance-min
CENTER_MIN="-8"
CENTER_MAX="8"
CENTER_DISTANCE_MIN="8"

# --project / --run-name / --output-dir
PROJECT="as-mamba"
RUN_NAME="sphere-to-sphere-dt"
OUTPUT_DIR="outputs"
#######################################
# Translate a true/false toggle into its paired command-line switch.
# Arguments: $1 - switch name without leading dashes (e.g. use-flow-loss)
#            $2 - toggle value; must be exactly "true" or "false"
# Outputs:   "--<name>" for true or "--no-<name>" for false on stdout;
#            a diagnostic on stderr for any other value
# Returns:   0 on success, 2 when the value is not recognised
#######################################
loss_toggle_flag() {
  local name=$1
  local value=$2
  case "${value}" in
    true)
      printf '%s\n' "--${name}"
      ;;
    false)
      printf '%s\n' "--no-${name}"
      ;;
    *)
      printf 'error: toggle for --%s must be "true" or "false", got "%s"\n' \
        "${name}" "${value}" >&2
      return 2
      ;;
  esac
}

# Only the literal strings "true" and "false" are accepted. Anything else
# (for example "False", "0" or a typo) used to enable the loss silently; it
# now makes the assignment fail, which stops the script under errexit before
# the training command is started.
USE_FLOW_FLAG="$(loss_toggle_flag use-flow-loss "${USE_FLOW_LOSS}")"
USE_POS_FLAG="$(loss_toggle_flag use-pos-loss "${USE_POS_LOSS}")"
USE_DT_FLAG="$(loss_toggle_flag use-dt-loss "${USE_DT_LOSS}")"
# Collect the main.py arguments in an array instead of one long
# backslash-continued line: every element stays exactly one word after
# expansion, and the groups can be commented. Order and values match the
# settings defined earlier in this script.
train_args=(
  --device "${DEVICE}"
  --epochs "${EPOCHS}"
  --steps-per-epoch "${STEPS_PER_EPOCH}"
  --batch-size "${BATCH_SIZE}"
  --seq-len "${SEQ_LEN}"
  --lr "${LR}"

  --dt-min "${DT_MIN}"
  --dt-max "${DT_MAX}"
  --dt-alpha "${DT_ALPHA}"

  --lambda-flow "${LAMBDA_FLOW}"
  --lambda-pos "${LAMBDA_POS}"
  --lambda-dt "${LAMBDA_DT}"

  # Pre-computed --use-*/--no-use-* switches. Quoted so they are never
  # subject to word splitting or globbing (ShellCheck SC2086).
  "${USE_FLOW_FLAG}"
  "${USE_POS_FLAG}"
  "${USE_DT_FLAG}"

  --val-every "${VAL_EVERY}"
  --val-samples "${VAL_SAMPLES}"
  --val-plot-samples "${VAL_PLOT_SAMPLES}"
  --val-max-steps "${VAL_MAX_STEPS}"

  --center-min "${CENTER_MIN}"
  --center-max "${CENTER_MAX}"
  --center-distance-min "${CENTER_DISTANCE_MIN}"

  --project "${PROJECT}"
  --run-name "${RUN_NAME}"
  --output-dir "${OUTPUT_DIR}"
)

# main.py is resolved relative to the current working directory.
uv run python main.py "${train_args[@]}"