code release
This commit is contained in:
25
diffusion_policy/policy/base_image_policy.py
Normal file
25
diffusion_policy/policy/base_image_policy.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Dict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||
from diffusion_policy.model.common.normalizer import LinearNormalizer
|
||||
|
||||
class BaseImagePolicy(ModuleAttrMixin):
|
||||
# init accepts keyword argument shape_meta, see config/task/*_image.yaml
|
||||
|
||||
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
obs_dict:
|
||||
str: B,To,*
|
||||
return: B,Ta,Da
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
# reset state for stateful policies
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
# ========== training ===========
|
||||
# no standard training interface except setting normalizer
|
||||
def set_normalizer(self, normalizer: LinearNormalizer):
|
||||
raise NotImplementedError()
|
||||
Reference in New Issue
Block a user