fix: support dual decoder relaunch and width-adaptive lewm tokens
This commit is contained in:
@@ -277,16 +277,23 @@ class IMFVLAAgent(VLAAgent):
|
||||
|
||||
if key in {query_key, pos_key} and source_tensor.ndim == current_tensor.ndim:
|
||||
patched = current_tensor.clone()
|
||||
overlap_slices = tuple(
|
||||
slice(0, min(src_dim, cur_dim))
|
||||
for src_dim, cur_dim in zip(source_tensor.shape, current_tensor.shape)
|
||||
)
|
||||
patched[overlap_slices] = source_tensor[overlap_slices]
|
||||
if key == query_key:
|
||||
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
|
||||
patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
|
||||
if copy_count < current_tensor.shape[1] and copy_count > 0:
|
||||
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||
tail = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||
feature_dim = min(tail.shape[-1], patched.shape[-1])
|
||||
patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
|
||||
else:
|
||||
copy_count = min(source_tensor.shape[1], current_tensor.shape[1])
|
||||
patched[:, :copy_count, ...] = source_tensor[:, :copy_count, ...]
|
||||
if copy_count < current_tensor.shape[1] and copy_count > 0:
|
||||
patched[:, copy_count:, ...] = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||
tail = source_tensor[:, copy_count - 1:copy_count, ...]
|
||||
feature_dim = min(tail.shape[-1], patched.shape[-1])
|
||||
patched[:, copy_count:, :feature_dim] = tail[:, :, :feature_dim]
|
||||
adapted_state_dict[key] = patched
|
||||
loaded_keys.append(key)
|
||||
continue
|
||||
|
||||
@@ -62,7 +62,7 @@ future_decoder:
|
||||
input_dim: ${agent.cond_projector.output_dim}
|
||||
output_dim: ${agent.cond_projector.output_dim}
|
||||
horizon: ${len:${agent.lewm_query_offsets}}
|
||||
n_obs_steps: ${agent.obs_horizon}
|
||||
n_obs_steps: ${agent.lewm_history_horizon}
|
||||
cond_dim: ${agent.cond_projector.output_dim}
|
||||
n_emb: 384
|
||||
causal_attn: false
|
||||
|
||||
@@ -1267,6 +1267,10 @@ class IMFVLAAgentTest(unittest.TestCase):
|
||||
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
|
||||
self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
|
||||
self.assertEqual(
|
||||
agent.future_decoder.constructor_kwargs['n_obs_steps'],
|
||||
agent.lewm_history_horizon,
|
||||
)
|
||||
self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user