Files
mamba_diffusion/train_as_mamba.sh

93 lines
2.3 KiB
Bash
Executable File

#!/usr/bin/env bash
#
# Launch main.py through `uv run torchrun` with the settings defined below.
#
# Usage: ./train_as_mamba.sh
#
# Every setting is a constant in this file; edit the values here to change a
# run. Each one is forwarded verbatim to the main.py command-line option of
# the corresponding name (see build list in main()); what each option means
# is defined by main.py, not by this script.
set -euo pipefail

# --- Launcher -----------------------------------------------------------------
# Number of worker processes torchrun starts on this node.
readonly NPROC_PER_NODE=2

# --- Run / optimisation settings (forwarded as-is) ------------------------------
readonly DEVICE="cuda"
readonly EPOCHS=2000
readonly STEPS_PER_EPOCH=200
readonly BATCH_SIZE=256
readonly SEQ_LEN=100
readonly LR=2e-3
readonly WEIGHT_DECAY=1e-2
readonly DT_MIN=5e-4
readonly DT_MAX=0.06
readonly DT_ALPHA=9.0

# --- Loss weights and loss toggles ----------------------------------------------
readonly LAMBDA_FLOW=1.0
readonly LAMBDA_POS=1.0
readonly LAMBDA_DT=1.0
# Toggles must be exactly "true" or "false" (validated by bool_flag).
readonly USE_FLOW_LOSS=true
readonly USE_POS_LOSS=false
readonly USE_DT_LOSS=true

# --- Data -----------------------------------------------------------------------
readonly NUM_CLASSES=10
readonly IMAGE_SIZE=28
readonly CHANNELS=1
readonly NUM_WORKERS=16
readonly DATASET_NAME="ylecun/mnist"
readonly DATASET_SPLIT="train"

# --- Model ----------------------------------------------------------------------
readonly D_MODEL=784
readonly N_LAYER=6
readonly D_STATE=32
readonly D_CONV=4
readonly EXPAND=2
readonly HEADDIM=32
readonly CHUNK_SIZE=20
readonly USE_RESIDUAL=false
# NOTE(review): torchrun is started with NPROC_PER_NODE workers regardless of
# this toggle; confirm main.py behaves sensibly if this is set to "false".
readonly USE_DDP=true

# --- Validation -----------------------------------------------------------------
readonly VAL_EVERY=1000
readonly VAL_SAMPLES_PER_CLASS=8
readonly VAL_GRID_ROWS=4
readonly VAL_MAX_STEPS=0

# --- Run identification / output ------------------------------------------------
readonly PROJECT="as-mamba-mnist"
readonly RUN_NAME="mnist-flow"
readonly OUTPUT_DIR="outputs"

#######################################
# Translate a true/false setting into a "--<stem>" / "--no-<stem>" option.
# Rejects anything else so that a typo cannot silently enable a feature.
# Arguments:
#   $1 - name of the setting (used only in the error message)
#   $2 - value of the setting; must be "true" or "false"
#   $3 - option stem, e.g. "use-flow-loss"
# Outputs:
#   The option on stdout; an error message on stderr for invalid values.
# Returns:
#   0 on success, 1 if the value is neither "true" nor "false".
#######################################
bool_flag() {
  local name=$1
  local value=$2
  local stem=$3

  case "${value}" in
    true)
      printf '%s\n' "--${stem}"
      ;;
    false)
      printf '%s\n' "--no-${stem}"
      ;;
    *)
      printf 'error: %s must be "true" or "false", got "%s"\n' \
        "${name}" "${value}" >&2
      return 1
      ;;
  esac
}

#######################################
# Assemble the full command line and run the training job.
# Globals:
#   All settings defined above (read only).
# Returns:
#   The exit status of the launched command.
#######################################
main() {
  # Resolve toggles first; declaration and assignment are kept separate so a
  # failing bool_flag aborts the script under `set -e` instead of being masked.
  local use_flow_flag use_pos_flag use_dt_flag use_residual_flag use_ddp_flag
  use_flow_flag=$(bool_flag USE_FLOW_LOSS "${USE_FLOW_LOSS}" use-flow-loss)
  use_pos_flag=$(bool_flag USE_POS_LOSS "${USE_POS_LOSS}" use-pos-loss)
  use_dt_flag=$(bool_flag USE_DT_LOSS "${USE_DT_LOSS}" use-dt-loss)
  use_residual_flag=$(bool_flag USE_RESIDUAL "${USE_RESIDUAL}" use-residual)
  use_ddp_flag=$(bool_flag USE_DDP "${USE_DDP}" use-ddp)

  # Arrays keep every argument as exactly one word, whatever it contains.
  local -a launcher=(
    uv run torchrun "--nproc_per_node=${NPROC_PER_NODE}" main.py
  )

  local -a train_args=(
    --device "${DEVICE}"
    --epochs "${EPOCHS}"
    --steps-per-epoch "${STEPS_PER_EPOCH}"
    --batch-size "${BATCH_SIZE}"
    --seq-len "${SEQ_LEN}"
    --lr "${LR}"
    --weight-decay "${WEIGHT_DECAY}"
    --dt-min "${DT_MIN}"
    --dt-max "${DT_MAX}"
    --dt-alpha "${DT_ALPHA}"
    --lambda-flow "${LAMBDA_FLOW}"
    --lambda-pos "${LAMBDA_POS}"
    --lambda-dt "${LAMBDA_DT}"
    "${use_flow_flag}"
    "${use_pos_flag}"
    "${use_dt_flag}"
    --num-classes "${NUM_CLASSES}"
    --image-size "${IMAGE_SIZE}"
    --channels "${CHANNELS}"
    --num-workers "${NUM_WORKERS}"
    --dataset-name "${DATASET_NAME}"
    --dataset-split "${DATASET_SPLIT}"
    --d-model "${D_MODEL}"
    --n-layer "${N_LAYER}"
    --d-state "${D_STATE}"
    --d-conv "${D_CONV}"
    --expand "${EXPAND}"
    --headdim "${HEADDIM}"
    --chunk-size "${CHUNK_SIZE}"
    "${use_residual_flag}"
    "${use_ddp_flag}"
    --val-every "${VAL_EVERY}"
    --val-samples-per-class "${VAL_SAMPLES_PER_CLASS}"
    --val-grid-rows "${VAL_GRID_ROWS}"
    --val-max-steps "${VAL_MAX_STEPS}"
    --project "${PROJECT}"
    --run-name "${RUN_NAME}"
    --output-dir "${OUTPUT_DIR}"
  )

  "${launcher[@]}" "${train_args[@]}"
}

main