Implement Mamba MeanFlow x-prediction training

This commit is contained in:
Logic
2026-03-11 16:33:40 +08:00
parent 01fc1e4eab
commit 9b2968997c
5 changed files with 353 additions and 121 deletions

View File

@@ -5,13 +5,11 @@ DEVICE="cuda"
EPOCHS=2000
STEPS_PER_EPOCH=200
BATCH_SIZE=512
SEQ_LEN=1
SEQ_LEN=5
LR=1e-3
WEIGHT_DECAY=1e-2
DT_MIN=5e-4
DT_MAX=1.1
DT_ALPHA=9.0
LAMBDA_FLOW=1.0
LAMBDA_PERCEPTUAL=0.4
NUM_CLASSES=10
IMAGE_SIZE=28
CHANNELS=1
@@ -30,8 +28,10 @@ USE_DDP=true
VAL_EVERY=1000
VAL_SAMPLES_PER_CLASS=8
VAL_GRID_ROWS=4
VAL_SAMPLING_STEPS=5
TIME_GRID_SIZE=256
PROJECT="as-mamba-mnist"
RUN_NAME="mnist-flow-res-5seq"
RUN_NAME="mnist-meanflow-xpred"
OUTPUT_DIR="outputs"
USE_RESIDUAL_FLAG="--use-residual"
@@ -47,10 +47,8 @@ uv run torchrun --nproc_per_node=2 main.py \
--seq-len "${SEQ_LEN}" \
--lr "${LR}" \
--weight-decay "${WEIGHT_DECAY}" \
--dt-min "${DT_MIN}" \
--dt-max "${DT_MAX}" \
--dt-alpha "${DT_ALPHA}" \
--lambda-flow "${LAMBDA_FLOW}" \
--lambda-perceptual "${LAMBDA_PERCEPTUAL}" \
--num-classes "${NUM_CLASSES}" \
--image-size "${IMAGE_SIZE}" \
--channels "${CHANNELS}" \
@@ -69,6 +67,8 @@ uv run torchrun --nproc_per_node=2 main.py \
--val-every "${VAL_EVERY}" \
--val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \
--val-grid-rows "${VAL_GRID_ROWS}" \
--val-sampling-steps "${VAL_SAMPLING_STEPS}" \
--time-grid-size "${TIME_GRID_SIZE}" \
--project "${PROJECT}" \
--run-name "${RUN_NAME}" \
--output-dir "${OUTPUT_DIR}"