feat: add PEMS and Solar dataset support
- Add Dataset_PEMS and Dataset_Solar classes for PEMS and Solar datasets - Update data_factory.py to include new dataset mappings - Fix M4 dataset handling with proper numpy array dtype - Add PEMS-specific loss function (L1Loss) and inverse transform support - Update validation logic for PEMS dataset with inverse scaling - Fix M4 data loader insample mask calculation bug Changes support new traffic and solar energy datasets while maintaining backward compatibility with existing datasets.
This commit is contained in:
@ -34,9 +34,19 @@ class Exp_Long_Term_Forecast(Exp_Basic):
|
||||
model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
|
||||
return model_optim
|
||||
|
||||
def _select_criterion(self):
|
||||
criterion = nn.MSELoss()
|
||||
return criterion
|
||||
def _select_criterion(self, loss_name='MSE'):
|
||||
if self.args.data == 'PEMS':
|
||||
return nn.L1Loss()
|
||||
elif loss_name == 'MSE':
|
||||
return nn.MSELoss()
|
||||
elif loss_name == 'MAPE':
|
||||
return mape_loss()
|
||||
elif loss_name == 'MASE':
|
||||
return mase_loss()
|
||||
elif loss_name == 'SMAPE':
|
||||
return smape_loss()
|
||||
elif loss_name == 'MAE':
|
||||
return nn.L1Loss(reduction='mean')
|
||||
|
||||
|
||||
def vali(self, vali_data, vali_loader, criterion):
|
||||
@ -66,9 +76,18 @@ class Exp_Long_Term_Forecast(Exp_Basic):
|
||||
pred = outputs.detach()
|
||||
true = batch_y.detach()
|
||||
|
||||
loss = criterion(pred, true)
|
||||
if self.args.data == 'PEMS':
|
||||
B, T, C = pred.shape
|
||||
pred = pred.cpu().numpy()
|
||||
true = true.cpu().numpy()
|
||||
pred = vali_data.inverse_transform(pred.reshape(-1, C)).reshape(B, T, C)
|
||||
true = vali_data.inverse_transform(true.reshape(-1, C)).reshape(B, T, C)
|
||||
mae, mse, rmse, mape, mspe = metric(pred, true)
|
||||
total_loss.append(mae)
|
||||
else:
|
||||
loss = criterion(pred, true)
|
||||
total_loss.append(loss.item())
|
||||
|
||||
total_loss.append(loss.item())
|
||||
total_loss = np.average(total_loss)
|
||||
self.model.train()
|
||||
return total_loss
|
||||
|
Reference in New Issue
Block a user