#!/usr/bin/env bash
#
# Launch script: runs main.py through `uv run torchrun` with two processes per
# node, forwarding the settings defined below as command-line flags.
#
# Usage: run this file directly; it takes no arguments. Edit the variables
# below to change a setting.

# Strict mode: abort on any failing command, on use of an unset variable, and
# when any stage of a pipeline fails.
set -euo pipefail

# Settings forwarded to main.py by the launch command at the bottom of this
# script. Each variable is passed as the like-named flag
# (e.g. STEPS_PER_EPOCH -> --steps-per-epoch). The section labels below are
# grouped by flag name only; see main.py for the exact meaning of each value.

# --device, --epochs, --steps-per-epoch, --batch-size, --seq-len
DEVICE="cuda"
EPOCHS=2000
STEPS_PER_EPOCH=200
BATCH_SIZE=512
SEQ_LEN=5

# --lr, --weight-decay, --lambda-flow, --lambda-perceptual
LR=1e-3
WEIGHT_DECAY=1e-2
LAMBDA_FLOW=1.0
LAMBDA_PERCEPTUAL=0.4

# --num-classes, --image-size, --channels, --num-workers,
# --dataset-name, --dataset-split
NUM_CLASSES=10
IMAGE_SIZE=28
CHANNELS=1
NUM_WORKERS=32
DATASET_NAME="ylecun/mnist"
DATASET_SPLIT="train"

# --d-model, --n-layer, --d-state, --d-conv, --expand, --headdim, --chunk-size
D_MODEL=784
N_LAYER=8
D_STATE=32
D_CONV=4
EXPAND=2
HEADDIM=32
CHUNK_SIZE=1

# On/off toggles. The exact string "false" selects the --no-use-residual /
# --no-use-ddp form; any other value selects --use-residual / --use-ddp.
USE_RESIDUAL=true
USE_DDP=true

# --val-every, --val-samples-per-class, --val-grid-rows,
# --val-sampling-steps, --time-grid-size
VAL_EVERY=1000
VAL_SAMPLES_PER_CLASS=8
VAL_GRID_ROWS=4
VAL_SAMPLING_STEPS=5
TIME_GRID_SIZE=256

# --project, --run-name, --output-dir
PROJECT="as-mamba-mnist"
RUN_NAME="mnist-meanflow-xpred"
OUTPUT_DIR="outputs"

#######################################
# Map a true/false setting onto a paired on/off command-line switch.
# Arguments: $1 - flag name without the leading dashes (e.g. use-ddp)
#            $2 - setting value
# Outputs:   "--no-<name>" on stdout when $2 is exactly "false",
#            otherwise "--<name>"
#######################################
bool_flag() {
  local name=$1
  local value=$2
  if [[ "${value}" == "false" ]]; then
    # `--` stops printf from treating the dash-led format as an option.
    printf -- '--no-%s\n' "${name}"
  else
    printf -- '--%s\n' "${name}"
  fi
}

# Switches handed to main.py; each is always exactly one non-empty word.
USE_RESIDUAL_FLAG="$(bool_flag use-residual "${USE_RESIDUAL}")"
USE_DDP_FLAG="$(bool_flag use-ddp "${USE_DDP}")"

# Arguments for main.py, collected in an array so every value stays a single
# word and no backslash line-continuations are needed. Order matches the
# original command line.
train_args=(
  --device "${DEVICE}"
  --epochs "${EPOCHS}"
  --steps-per-epoch "${STEPS_PER_EPOCH}"
  --batch-size "${BATCH_SIZE}"
  --seq-len "${SEQ_LEN}"
  --lr "${LR}"
  --weight-decay "${WEIGHT_DECAY}"
  --lambda-flow "${LAMBDA_FLOW}"
  --lambda-perceptual "${LAMBDA_PERCEPTUAL}"
  --num-classes "${NUM_CLASSES}"
  --image-size "${IMAGE_SIZE}"
  --channels "${CHANNELS}"
  --num-workers "${NUM_WORKERS}"
  --dataset-name "${DATASET_NAME}"
  --dataset-split "${DATASET_SPLIT}"
  --d-model "${D_MODEL}"
  --n-layer "${N_LAYER}"
  --d-state "${D_STATE}"
  --d-conv "${D_CONV}"
  --expand "${EXPAND}"
  --headdim "${HEADDIM}"
  --chunk-size "${CHUNK_SIZE}"
  # Pre-computed on/off switches; each holds exactly one non-empty word, so
  # quoting them does not change the argument list.
  "${USE_RESIDUAL_FLAG}"
  "${USE_DDP_FLAG}"
  --val-every "${VAL_EVERY}"
  --val-samples-per-class "${VAL_SAMPLES_PER_CLASS}"
  --val-grid-rows "${VAL_GRID_ROWS}"
  --val-sampling-steps "${VAL_SAMPLING_STEPS}"
  --time-grid-size "${TIME_GRID_SIZE}"
  --project "${PROJECT}"
  --run-name "${RUN_NAME}"
  --output-dir "${OUTPUT_DIR}"
)

# Start main.py via torchrun inside the uv-managed environment; the number of
# processes per node is fixed at 2.
uv run torchrun --nproc_per_node=2 main.py "${train_args[@]}"