#!/usr/bin/env bash
#
# Launches main.py through `uv run torchrun` with the settings defined below.
#
# Each setting is forwarded to main.py as the command-line option of the same
# name (EPOCHS -> --epochs, DT_MIN -> --dt-min, ...). To change a run, edit the
# assignments in this file; the script itself takes no arguments.
#
# main.py is given as a relative path, so run this script from the directory
# that contains it.

set -euo pipefail

# Number of processes torchrun starts on this node (--nproc_per_node).
NPROC_PER_NODE=2

DEVICE="cuda"
EPOCHS=2000
STEPS_PER_EPOCH=200
BATCH_SIZE=256
SEQ_LEN=100
LR=2e-3
WEIGHT_DECAY=1e-2

DT_MIN=5e-4
DT_MAX=0.06
DT_ALPHA=9.0

LAMBDA_FLOW=1.0
LAMBDA_POS=1.0
LAMBDA_DT=1.0

# Boolean settings must be exactly "true" or "false"; see bool_flag below.
USE_FLOW_LOSS=true
USE_POS_LOSS=false
USE_DT_LOSS=true

NUM_CLASSES=10
IMAGE_SIZE=28
CHANNELS=1
NUM_WORKERS=16
DATASET_NAME="ylecun/mnist"
DATASET_SPLIT="train"

D_MODEL=784
N_LAYER=6
D_STATE=32
D_CONV=4
EXPAND=2
HEADDIM=32
CHUNK_SIZE=20
USE_RESIDUAL=false
USE_DDP=true

VAL_EVERY=1000
VAL_SAMPLES_PER_CLASS=8
VAL_GRID_ROWS=4
VAL_MAX_STEPS=0

PROJECT="as-mamba-mnist"
RUN_NAME="mnist-flow"
OUTPUT_DIR="outputs"

#######################################
# Prints the command-line switch for a boolean setting.
# Arguments:
#   $1 - option name without leading dashes (e.g. use-flow-loss)
#   $2 - the setting's value; must be "true" or "false"
# Outputs:
#   "--<name>" for true, "--no-<name>" for false, on stdout.
#   An error message on stderr for any other value.
# Returns:
#   0 on success, 2 if the value is neither "true" nor "false".
#######################################
bool_flag() {
  local name=$1
  local value=$2

  case "${value}" in
    true)
      printf '%s\n' "--${name}"
      ;;
    false)
      printf '%s\n' "--no-${name}"
      ;;
    *)
      # Reject typos such as "False" or "0" instead of silently enabling
      # the option.
      printf 'error: setting for --%s must be "true" or "false", got "%s"\n' \
        "${name}" "${value}" >&2
      return 2
      ;;
  esac
}

# Plain assignments (not inside the array literal below) so that a failing
# bool_flag aborts the script under `set -e`.
USE_FLOW_FLAG="$(bool_flag use-flow-loss "${USE_FLOW_LOSS}")"
USE_POS_FLAG="$(bool_flag use-pos-loss "${USE_POS_LOSS}")"
USE_DT_FLAG="$(bool_flag use-dt-loss "${USE_DT_LOSS}")"
USE_RESIDUAL_FLAG="$(bool_flag use-residual "${USE_RESIDUAL}")"
USE_DDP_FLAG="$(bool_flag use-ddp "${USE_DDP}")"

# The command is assembled as an array so every argument stays a single word
# regardless of its contents.
cmd=(
  uv run torchrun "--nproc_per_node=${NPROC_PER_NODE}" main.py
  --device "${DEVICE}"
  --epochs "${EPOCHS}"
  --steps-per-epoch "${STEPS_PER_EPOCH}"
  --batch-size "${BATCH_SIZE}"
  --seq-len "${SEQ_LEN}"
  --lr "${LR}"
  --weight-decay "${WEIGHT_DECAY}"
  --dt-min "${DT_MIN}"
  --dt-max "${DT_MAX}"
  --dt-alpha "${DT_ALPHA}"
  --lambda-flow "${LAMBDA_FLOW}"
  --lambda-pos "${LAMBDA_POS}"
  --lambda-dt "${LAMBDA_DT}"
  "${USE_FLOW_FLAG}"
  "${USE_POS_FLAG}"
  "${USE_DT_FLAG}"
  --num-classes "${NUM_CLASSES}"
  --image-size "${IMAGE_SIZE}"
  --channels "${CHANNELS}"
  --num-workers "${NUM_WORKERS}"
  --dataset-name "${DATASET_NAME}"
  --dataset-split "${DATASET_SPLIT}"
  --d-model "${D_MODEL}"
  --n-layer "${N_LAYER}"
  --d-state "${D_STATE}"
  --d-conv "${D_CONV}"
  --expand "${EXPAND}"
  --headdim "${HEADDIM}"
  --chunk-size "${CHUNK_SIZE}"
  "${USE_RESIDUAL_FLAG}"
  "${USE_DDP_FLAG}"
  --val-every "${VAL_EVERY}"
  --val-samples-per-class "${VAL_SAMPLES_PER_CLASS}"
  --val-grid-rows "${VAL_GRID_ROWS}"
  --val-max-steps "${VAL_MAX_STEPS}"
  --project "${PROJECT}"
  --run-name "${RUN_NAME}"
  --output-dir "${OUTPUT_DIR}"
)

"${cmd[@]}"