From 4cd33258d257236a707ad344f023bbe57cc0ec15 Mon Sep 17 00:00:00 2001 From: Logic Date: Fri, 17 Apr 2026 20:11:54 +0800 Subject: [PATCH] fix: support dual decoder relaunch and width-adaptive lewm tokens --- roboimi/vla/agent_imf.py | 15 +++++++++++---- .../lewm_resnet_dual_decoder_imf_attnres.yaml | 2 +- tests/test_imf_vla_agent.py | 4 ++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/roboimi/vla/agent_imf.py b/roboimi/vla/agent_imf.py index 5a6fa93..2cbdc4b 100644 --- a/roboimi/vla/agent_imf.py +++ b/roboimi/vla/agent_imf.py @@ -277,16 +277,23 @@ class IMFVLAAgent(VLAAgent): if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim: patched = current_tensor.clone() + overlap_slices = tuple( + slice(0, min(src_dim, cur_dim)) + for src_dim, cur_dim in zip(source_tensor.shape, current_tensor.shape) + ) + patched[overlap_slices] = source_tensor[overlap_slices] if key == query_key: copy_count = min(source_tensor.shape[1], current_tensor.shape[1]) - patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...] if copy_count < current_tensor.shape[1] and copy_count > 0: - patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...] + tail = source_tensor[:, copy_count - 1:copy_count, ...] + feature_dim = min(tail.shape[-1], patched.shape[-1]) + patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim] else: copy_count = min(source_tensor.shape[1], current_tensor.shape[1]) - patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...] if copy_count < current_tensor.shape[1] and copy_count > 0: - patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...] + tail = source_tensor[:, copy_count - 1:copy_count, ...] + feature_dim = min(tail.shape[-1], patched.shape[-1]) + patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim] adapted_state_dict[key] = patched loaded_keys.append(key) continue diff --git a/roboimi/vla/conf/agent/lewm_resnet_dual_decoder_imf_attnres.yaml b/roboimi/vla/conf/agent/lewm_resnet_dual_decoder_imf_attnres.yaml index 3437fbb..f58abd1 100644 --- a/roboimi/vla/conf/agent/lewm_resnet_dual_decoder_imf_attnres.yaml +++ b/roboimi/vla/conf/agent/lewm_resnet_dual_decoder_imf_attnres.yaml @@ -62,7 +62,7 @@ future_decoder: input_dim: ${agent.cond_projector.output_dim} output_dim: ${agent.cond_projector.output_dim} horizon: ${len:${agent.lewm_query_offsets}} - n_obs_steps: ${agent.obs_horizon} + n_obs_steps: ${agent.lewm_history_horizon} cond_dim: ${agent.cond_projector.output_dim} n_emb: 384 causal_attn: false diff --git a/tests/test_imf_vla_agent.py b/tests/test_imf_vla_agent.py index b14112b..34b422d 100644 --- a/tests/test_imf_vla_agent.py +++ b/tests/test_imf_vla_agent.py @@ -1267,6 +1267,10 @@ class IMFVLAAgentTest(unittest.TestCase): self.assertEqual(agent.condition_sequence_length, agent.obs_horizon) self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon) self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288) + self.assertEqual( + agent.future_decoder.constructor_kwargs['n_obs_steps'], + agent.lewm_history_horizon, + ) self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))