36 lines
1018 B
Python
36 lines
1018 B
Python
from typing import Dict
|
|
import torch
|
|
import torch.nn as nn
|
|
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
|
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
|
|
|
class BaseLowdimPolicy(ModuleAttrMixin):
|
|
# ========= inference ============
|
|
# also as self.device and self.dtype for inference device transfer
|
|
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
"""
|
|
obs_dict:
|
|
obs: B,To,Do
|
|
return:
|
|
action: B,Ta,Da
|
|
To = 3
|
|
Ta = 4
|
|
T = 6
|
|
|o|o|o|
|
|
| | |a|a|a|a|
|
|
|o|o|
|
|
| |a|a|a|a|a|
|
|
| | | | |a|a|
|
|
"""
|
|
raise NotImplementedError()
|
|
|
|
# reset state for stateful policies
|
|
def reset(self):
|
|
pass
|
|
|
|
# ========== training ===========
|
|
# no standard training interface except setting normalizer
|
|
def set_normalizer(self, normalizer: LinearNormalizer):
|
|
raise NotImplementedError()
|
|
|
|
|