diff --git a/utils/tools.py b/utils/tools.py index 6efa712..701d0b6 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -22,6 +22,11 @@ def adjust_learning_rate(optimizer, epoch, args): lr_adjust = {epoch: args.learning_rate if epoch < 3 else args.learning_rate * (0.9 ** ((epoch - 3) // 1))} elif args.lradj == "cosine": lr_adjust = {epoch: args.learning_rate /2 * (1 + math.cos(epoch / args.train_epochs * math.pi))} + elif args.lradj == 'sigmoid': + k = 0.5 # logistic growth rate + s = 10 # decreasing curve smoothing rate + w = 10 # warm-up coefficient + lr_adjust = {epoch: args.learning_rate / (1 + np.exp(-k * (epoch - w))) - args.learning_rate / (1 + np.exp(-k/s * (epoch - w*s)))} if epoch in lr_adjust.keys(): lr = lr_adjust[epoch] for param_group in optimizer.param_groups: