feat(vla): align transformer training stack and rollout validation
This commit is contained in:
28
tests/test_eval_vla_execution.py
Normal file
28
tests/test_eval_vla_execution.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import unittest
|
||||
|
||||
from roboimi.vla.eval_utils import execute_policy_action
|
||||
|
||||
|
||||
class _FakeEnv:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def step(self, action):
|
||||
self.calls.append(("step", action))
|
||||
|
||||
def step_jnt(self, action):
|
||||
self.calls.append(("step_jnt", action))
|
||||
|
||||
|
||||
class EvalVLAExecutionTest(unittest.TestCase):
|
||||
def test_execute_policy_action_uses_ee_step(self):
|
||||
env = _FakeEnv()
|
||||
action = [1, 2, 3]
|
||||
|
||||
execute_policy_action(env, action)
|
||||
|
||||
self.assertEqual(env.calls, [("step", action)])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user