feat: migrate switch to conditional flow matching from sphere trajectory
This commit is contained in:
@@ -2,29 +2,42 @@
|
||||
set -euo pipefail
|
||||
|
||||
DEVICE="cuda"
|
||||
EPOCHS=80
|
||||
STEPS_PER_EPOCH=100
|
||||
EPOCHS=2000
|
||||
STEPS_PER_EPOCH=200
|
||||
BATCH_SIZE=256
|
||||
SEQ_LEN=20
|
||||
LR=1e-3
|
||||
DT_MIN=1e-3
|
||||
DT_MAX=0.10
|
||||
DT_ALPHA=6.0
|
||||
SEQ_LEN=100
|
||||
LR=2e-3
|
||||
WEIGHT_DECAY=1e-2
|
||||
DT_MIN=5e-4
|
||||
DT_MAX=0.06
|
||||
DT_ALPHA=9.0
|
||||
LAMBDA_FLOW=1.0
|
||||
LAMBDA_POS=0.0
|
||||
LAMBDA_DT=0.5
|
||||
LAMBDA_POS=1.0
|
||||
LAMBDA_DT=1.0
|
||||
USE_FLOW_LOSS=true
|
||||
USE_POS_LOSS=false
|
||||
USE_DT_LOSS=true
|
||||
VAL_EVERY=200
|
||||
VAL_SAMPLES=512
|
||||
VAL_PLOT_SAMPLES=16
|
||||
VAL_MAX_STEPS=100
|
||||
CENTER_MIN=-8
|
||||
CENTER_MAX=8
|
||||
CENTER_DISTANCE_MIN=8
|
||||
PROJECT="as-mamba"
|
||||
RUN_NAME="sphere-to-sphere-dt"
|
||||
NUM_CLASSES=10
|
||||
IMAGE_SIZE=28
|
||||
CHANNELS=1
|
||||
NUM_WORKERS=16
|
||||
DATASET_NAME="ylecun/mnist"
|
||||
DATASET_SPLIT="train"
|
||||
D_MODEL=784
|
||||
N_LAYER=6
|
||||
D_STATE=32
|
||||
D_CONV=4
|
||||
EXPAND=2
|
||||
HEADDIM=32
|
||||
CHUNK_SIZE=20
|
||||
USE_RESIDUAL=false
|
||||
USE_DDP=true
|
||||
VAL_EVERY=1000
|
||||
VAL_SAMPLES_PER_CLASS=8
|
||||
VAL_GRID_ROWS=4
|
||||
VAL_MAX_STEPS=0
|
||||
PROJECT="as-mamba-mnist"
|
||||
RUN_NAME="mnist-flow"
|
||||
OUTPUT_DIR="outputs"
|
||||
|
||||
USE_FLOW_FLAG="--use-flow-loss"
|
||||
@@ -33,14 +46,19 @@ USE_POS_FLAG="--use-pos-loss"
|
||||
if [ "${USE_POS_LOSS}" = "false" ]; then USE_POS_FLAG="--no-use-pos-loss"; fi
|
||||
USE_DT_FLAG="--use-dt-loss"
|
||||
if [ "${USE_DT_LOSS}" = "false" ]; then USE_DT_FLAG="--no-use-dt-loss"; fi
|
||||
USE_RESIDUAL_FLAG="--use-residual"
|
||||
if [ "${USE_RESIDUAL}" = "false" ]; then USE_RESIDUAL_FLAG="--no-use-residual"; fi
|
||||
USE_DDP_FLAG="--use-ddp"
|
||||
if [ "${USE_DDP}" = "false" ]; then USE_DDP_FLAG="--no-use-ddp"; fi
|
||||
|
||||
uv run python main.py \
|
||||
uv run torchrun --nproc_per_node=2 main.py \
|
||||
--device "${DEVICE}" \
|
||||
--epochs "${EPOCHS}" \
|
||||
--steps-per-epoch "${STEPS_PER_EPOCH}" \
|
||||
--batch-size "${BATCH_SIZE}" \
|
||||
--seq-len "${SEQ_LEN}" \
|
||||
--lr "${LR}" \
|
||||
--weight-decay "${WEIGHT_DECAY}" \
|
||||
--dt-min "${DT_MIN}" \
|
||||
--dt-max "${DT_MAX}" \
|
||||
--dt-alpha "${DT_ALPHA}" \
|
||||
@@ -50,13 +68,25 @@ uv run python main.py \
|
||||
${USE_FLOW_FLAG} \
|
||||
${USE_POS_FLAG} \
|
||||
${USE_DT_FLAG} \
|
||||
--num-classes "${NUM_CLASSES}" \
|
||||
--image-size "${IMAGE_SIZE}" \
|
||||
--channels "${CHANNELS}" \
|
||||
--num-workers "${NUM_WORKERS}" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-split "${DATASET_SPLIT}" \
|
||||
--d-model "${D_MODEL}" \
|
||||
--n-layer "${N_LAYER}" \
|
||||
--d-state "${D_STATE}" \
|
||||
--d-conv "${D_CONV}" \
|
||||
--expand "${EXPAND}" \
|
||||
--headdim "${HEADDIM}" \
|
||||
--chunk-size "${CHUNK_SIZE}" \
|
||||
${USE_RESIDUAL_FLAG} \
|
||||
${USE_DDP_FLAG} \
|
||||
--val-every "${VAL_EVERY}" \
|
||||
--val-samples "${VAL_SAMPLES}" \
|
||||
--val-plot-samples "${VAL_PLOT_SAMPLES}" \
|
||||
--val-samples-per-class "${VAL_SAMPLES_PER_CLASS}" \
|
||||
--val-grid-rows "${VAL_GRID_ROWS}" \
|
||||
--val-max-steps "${VAL_MAX_STEPS}" \
|
||||
--center-min "${CENTER_MIN}" \
|
||||
--center-max "${CENTER_MAX}" \
|
||||
--center-distance-min "${CENTER_DISTANCE_MIN}" \
|
||||
--project "${PROJECT}" \
|
||||
--run-name "${RUN_NAME}" \
|
||||
--output-dir "${OUTPUT_DIR}"
|
||||
|
||||
Reference in New Issue
Block a user