import os
|
|
import torch
|
|
from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
|
|
Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
|
|
Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, PAttn, TimeXer, \
|
|
WPMixer, MultiPatchFormer, xPatch_SparseChannel
|
|
|
|
|
|
class Exp_Basic(object):
    """Base class for experiments.

    Holds the registry of available model classes, resolves the compute
    device from ``args``, and builds the concrete model. Subclasses must
    implement :meth:`_build_model`; the data/train/eval hooks are no-op
    stubs meant to be overridden.
    """

    def __init__(self, args):
        """Store ``args``, register models, pick a device, and build the model.

        Args:
            args: experiment configuration namespace. This class reads
                ``model``, ``use_gpu``, ``gpu_type``, ``gpu``,
                ``use_multi_gpu`` and ``devices``.
        """
        self.args = args
        # Registry mapping model name (args.model) -> model module/class.
        self.model_dict = {
            'TimesNet': TimesNet,
            'Autoformer': Autoformer,
            'Transformer': Transformer,
            'Nonstationary_Transformer': Nonstationary_Transformer,
            'DLinear': DLinear,
            'FEDformer': FEDformer,
            'Informer': Informer,
            'LightTS': LightTS,
            'Reformer': Reformer,
            'ETSformer': ETSformer,
            'PatchTST': PatchTST,
            'Pyraformer': Pyraformer,
            'MICN': MICN,
            'Crossformer': Crossformer,
            'FiLM': FiLM,
            'iTransformer': iTransformer,
            'Koopa': Koopa,
            'TiDE': TiDE,
            'FreTS': FreTS,
            'MambaSimple': MambaSimple,
            'TimeMixer': TimeMixer,
            'TSMixer': TSMixer,
            'SegRNN': SegRNN,
            'TemporalFusionTransformer': TemporalFusionTransformer,
            'SCINet': SCINet,
            'PAttn': PAttn,
            'TimeXer': TimeXer,
            'WPMixer': WPMixer,
            'MultiPatchFormer': MultiPatchFormer,
            'xPatch_SparseChannel': xPatch_SparseChannel
        }
        # Mamba depends on the optional mamba_ssm package, so it is only
        # imported (and registered) when explicitly requested.
        if args.model == 'Mamba':
            print('Please make sure you have successfully installed mamba_ssm')
            from models import Mamba
            self.model_dict['Mamba'] = Mamba

        self.device = self._acquire_device()
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        """Construct and return the model instance.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def _acquire_device(self):
        """Select the torch device from ``self.args`` and return it.

        Returns:
            torch.device: ``cuda:<gpu>``, ``mps``, or ``cpu`` depending on
            ``args.use_gpu`` and ``args.gpu_type``.
        """
        if self.args.use_gpu and self.args.gpu_type == 'cuda':
            # Restrict visible GPUs before creating the device; for
            # multi-GPU runs args.devices is a comma-separated id list.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(
                self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
            device = torch.device('cuda:{}'.format(self.args.gpu))
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        elif self.args.use_gpu and self.args.gpu_type == 'mps':
            device = torch.device('mps')
            print('Use GPU: mps')
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _get_data(self):
        """Return the dataset/loader; stub for subclasses to override."""
        pass

    def vali(self):
        """Run validation; stub for subclasses to override."""
        pass

    def train(self):
        """Run training; stub for subclasses to override."""
        pass

    def test(self):
        """Run testing; stub for subclasses to override."""
        pass
|