fix: support dual decoder relaunch and width-adaptive lewm tokens

This commit is contained in:
Logic
2026-04-17 20:11:54 +08:00
parent d8066823e2
commit 4cd33258d2
3 changed files with 16 additions and 5 deletions

View File

@@ -277,16 +277,23 @@ class IMFVLAAgent(VLAAgent):
 if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
     patched = current_tensor.clone()
+    overlap_slices = tuple(
+        slice(0, min(src_dim, cur_dim))
+        for src_dim, cur_dim in zip(source_tensor.shape, current_tensor.shape)
+    )
+    patched[overlap_slices] = source_tensor[overlap_slices]
     if key == query_key:
         copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
-        patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
         if copy_count < current_tensor.shape[1] and copy_count > 0:
-            patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
+            tail = source_tensor[:, copy_count - 1:copy_count, ...]
+            feature_dim = min(tail.shape[-1], patched.shape[-1])
+            patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
     else:
         copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
-        patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
         if copy_count < current_tensor.shape[1] and copy_count > 0:
-            patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
+            tail = source_tensor[:, copy_count - 1:copy_count, ...]
+            feature_dim = min(tail.shape[-1], patched.shape[-1])
+            patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
     adapted_state_dict[key] = patched
     loaded_keys.append(key)
     continue

View File

@@ -62,7 +62,7 @@ future_decoder:
 input_dim: ${agent.cond_projector.output_dim}
 output_dim: ${agent.cond_projector.output_dim}
 horizon: ${len:${agent.lewm_query_offsets}}
-n_obs_steps: ${agent.obs_horizon}
+n_obs_steps: ${agent.lewm_history_horizon}
 cond_dim: ${agent.cond_projector.output_dim}
 n_emb: 384
 causal_attn: false

View File

@@ -1267,6 +1267,10 @@ class IMFVLAAgentTest(unittest.TestCase):
 self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
 self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
 self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
+self.assertEqual(
+    agent.future_decoder.constructor_kwargs['n_obs_steps'],
+    agent.lewm_history_horizon,
+)
 self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))