May 28, 26
Slide overview
DL reading group materials
DEEP LEARNING JP [DL Papers] “Causal-JEPA: Learning World Models through Object-Level Latent Interventions” (ICML 2026) Jianqi Yang (Matsuo-Iwasawa Lab, M1) http://deeplearning.jp/
Bibliography Information
Title: Causal-JEPA: Learning World Models through Object-Level Latent Interventions
Authors: Nam (Brown), Le Lidec (NYU), Maes (Mila), LeCun (NYU), Balestriero (Brown)
Affiliations: Brown University, New York University, Mila
Publication: ICML 2026
arXiv: https://arxiv.org/abs/2602.11389
Summary
• Introduces Causal-JEPA, extending Joint-Embedding Predictive Architectures with object-level latent interventions to learn causal structure in world models.
• Demonstrates improved generalization on counterfactual reasoning and downstream manipulation tasks compared to standard JEPA baselines.
Problem: What World Models Need / Object-centric Alone is Not Sufficient
• World models need to predict how things interact, not just how they move
• Object-centric representations help, but on their own they're not enough
• Without an explicit interaction signal, models cheat by tracking each object alone
• Question: can the learning objective itself force interaction reasoning?
Limitations of Patch Masking and C-JEPA's Approach
• Standard patch masking (MAE, V-JEPA…) learns local pixel correlations
• Nothing forces objects to "talk to each other"
• C-JEPA's fix: mask whole objects, so the model must infer them from the others (figure on the right)
C-JEPA Pipeline Overview
• Frozen encoder turns each frame into object slots
• Mask selected object slots + all future frames
• Predictor recovers the masked tokens
• Actions / proprioception ride along as separate entities
Object-Level Masking
• Hide selected object slots across the entire history
• Keep only the first frame as an "identity anchor", so the model knows which object is which
• Reading: "what would this scene look like if I couldn't see object X?"
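The masking rule above can be sketched as a boolean grid over (frame, slot). This is a minimal illustration, not the authors' code: the function name `build_object_mask`, its parameters, and the toy sizes are all assumptions made for the example. `True` means the slot is hidden from the predictor.

```python
def build_object_mask(num_history, num_future, num_slots, masked_slots):
    """Return mask[t][k] = True where slot k at frame t is hidden from the predictor."""
    masked = set(masked_slots)
    mask = []
    for t in range(num_history + num_future):
        if t >= num_history:
            # Future frames: every slot is hidden and must be predicted.
            row = [True] * num_slots
        elif t == 0:
            # First frame stays fully visible: the identity anchor.
            row = [False] * num_slots
        else:
            # Rest of the history: hide only the selected objects.
            row = [k in masked for k in range(num_slots)]
        mask.append(row)
    return mask


# Toy sizes (assumed): 3 history frames, 2 future frames, 4 slots, object 1 masked.
mask = build_object_mask(num_history=3, num_future=2, num_slots=4, masked_slots=[1])
for t, row in enumerate(mask):
    print(t, "".join("X" if hidden else "." for hidden in row))
```

Each printed row is one frame: the masked object is visible only in frame 0, and every slot is hidden in the future frames.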
Learning Objective
• Predictor: bidirectional masked transformer (not autoregressive)
• Total loss decomposes into two terms:
• History term: recover masked objects from the OTHERS
  – Kills the self-dynamics shortcut: the model can't just propagate the object's own past
• Future term: standard forward prediction
• Together: interaction reasoning becomes functionally necessary for low loss
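The two-term decomposition can be written out as below. This is a sketch under assumed notation, not the paper's exact formula: z_t^k is the frozen-encoder slot of object k at frame t, \hat{z}_t^k is the predictor output, M is the set of masked objects, T is the history length, H is the future horizon, and N is the number of slots. The squared-error distance and the unweighted sum are assumptions. The history sum starts at t = 2 because the first frame stays visible as the identity anchor.

```latex
\mathcal{L}
  = \underbrace{\sum_{k \in M} \sum_{t=2}^{T}
      \bigl\lVert \hat{z}_t^{k} - z_t^{k} \bigr\rVert_2^2}_{\text{history term}}
  + \underbrace{\sum_{k=1}^{N} \sum_{t=T+1}^{T+H}
      \bigl\lVert \hat{z}_t^{k} - z_t^{k} \bigr\rVert_2^2}_{\text{future term}}
```

In the history term the inputs for object k at frames 2..T are masked, so \hat{z}_t^k can depend only on the other objects and on the first-frame anchor of k.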
Why "Causal"?
• "Causal" here means predictive dependencies that survive masking
• NOT causal identification, NOT do-calculus, NOT a graph
• Masking acts as a latent intervention on what the model can see
• → a causal inductive bias baked into the loss
Experiment 1: CLEVRER / Visual Reasoning
• CLEVRER: synthetic videos with descriptive / predictive / explanatory / counterfactual questions
• Rollout 128 → 160 frames, then answer questions via ALOE
• Counterfactual: 47.7 → 68.8 (+21.1 points absolute)
• Descriptive barely moves (+3%) → masking specifically targets counterfactuals
Experiment 2: Push-T / Predictive Control
Experimental Setup
• Push-T (Chi+ 2024): contact-rich manipulation (push the green T into the target pose of the gray T)
• Planning: finite-horizon optimal control + CEM + MPC in the latent object-centric state space
• N=4 slots; 50 env steps/episode; goal = state 25 steps later
Key Results (read top-to-bottom in the table above)
• Patch baselines (196×384 tokens, top of table): DINO-WM 91.33%, DINO-WM-Reg. 88.00% (register variant ≈ same)
• Object-centric models (6×128, i.e. 1.02% of the patch models' 196×384): OC-DINO-WM 60.67 → OC-JEPA 76.00 (+15.33 from JEPA) → C-JEPA 88.67 (+28.00 over OC-DINO-WM once object masking is added, i.e. +12.67 over OC-JEPA)
• → C-JEPA comes within 2.66 points of the patch baseline (91.33%) using only 1.02% of the tokens, with planning 8.6× faster (5,763s → 673s on a single L40S GPU, 50 trajectories)
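The planning loop named in the setup (CEM inside a receding-horizon MPC loop) can be illustrated on a toy problem. The sketch below is not the paper's planner: `toy_model` is a hypothetical 1-D stand-in for the learned latent dynamics, and the horizon, sample count, elite count, and cost function are illustrative assumptions.

```python
import random


def toy_model(state, action):
    """Hypothetical stand-in for the learned latent dynamics: a 1-D point with clipped steps."""
    return state + max(-1.0, min(1.0, action))


def rollout_cost(state, actions, goal):
    """Sum of squared distances to the goal along an imagined rollout."""
    cost = 0.0
    for a in actions:
        state = toy_model(state, a)
        cost += (state - goal) ** 2
    return cost


def cem_plan(state, goal, rng, horizon=5, samples=128, elites=16, iters=5):
    """Cross-entropy method: repeatedly refit a Gaussian over action sequences to the elites."""
    mean = [0.0] * horizon
    std = [1.0] * horizon
    for _ in range(iters):
        candidates = [
            [rng.gauss(mean[h], std[h]) for h in range(horizon)]
            for _ in range(samples)
        ]
        candidates.sort(key=lambda seq: rollout_cost(state, seq, goal))
        best = candidates[:elites]
        for h in range(horizon):
            column = [seq[h] for seq in best]
            mean[h] = sum(column) / elites
            var = sum((a - mean[h]) ** 2 for a in column) / elites
            std[h] = max(var ** 0.5, 0.05)  # floor keeps a little exploration
    return mean


def mpc_control(start, goal, steps=10, seed=0):
    """Receding horizon: replan at every step and execute only the first planned action."""
    rng = random.Random(seed)
    state = start
    for _ in range(steps):
        plan = cem_plan(state, goal, rng)
        state = toy_model(state, plan[0])
    return state


final_state = mpc_control(start=0.0, goal=3.0)
print(f"final state: {final_state:.2f} (goal 3.00)")
```

In this kind of planner the world model is called once per candidate per horizon step, so a smaller latent state makes every one of those calls cheaper.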
Auxiliary Variables + Masking Strategy Ablation (Figure 3)
• Auxiliary inputs
  – Treating actions / proprioception as separate entities beats concatenation
• Masking unit ablation
  – Object-level masking beats token-level and tube-level on interaction-heavy tasks (Appendix J)
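The two ways of feeding auxiliary inputs differ in the shape of the token set the predictor sees. The sketch below only illustrates that shape difference. The function names, the vector sizes, and the zero-padding `pad_to` (a stand-in for whatever learned projection a real model would use) are all assumptions.

```python
def concat_inputs(slots, action, proprio):
    """Option A: append the auxiliary vectors to every slot (same count, wider tokens)."""
    return [slot + action + proprio for slot in slots]


def pad_to(vec, dim):
    """Hypothetical stand-in for a learned projection: zero-pad to the slot width.

    Assumes len(vec) <= dim.
    """
    return vec + [0.0] * (dim - len(vec))


def separate_entities(slots, action, proprio):
    """Option B: add action and proprioception as extra tokens of slot width."""
    dim = len(slots[0])
    return slots + [pad_to(action, dim), pad_to(proprio, dim)]


slots = [[0.1 * k] * 8 for k in range(4)]  # 4 object slots of width 8 (toy sizes)
action = [0.5, -0.5]                       # toy 2-D action
proprio = [1.0, 2.0, 3.0]                  # toy 3-D proprioceptive state

tokens_a = concat_inputs(slots, action, proprio)
tokens_b = separate_entities(slots, action, proprio)
print("concatenation:    ", len(tokens_a), "tokens of width", len(tokens_a[0]))
print("separate entities:", len(tokens_b), "tokens of width", len(tokens_b[0]))
```

With separate entities the action and proprioception become tokens of their own, so attention can address them, and a masking scheme can treat them, independently of the object slots.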
Theory: Influence Neighborhood
• 4 Assumptions: temporal direction · shared transition · object-aligned slots · finite history
  – (no causal sufficiency assumed → unobserved confounders allowed)
• Definition 1 (Influence Neighborhood N_t(i)): minimal sufficient set to recover the masked object
• Theorem 1: Any MSE-optimal predictor MUST use information in N_t(i)
• Corollary 1: Repeated masking → attention patterns align with N_t(i) [~ Invariant Causal Prediction / IRM]
• Markov blanket analog (Appendix B.2); practical alternative to full causal discovery
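The intuition behind Theorem 1 can be shown with a toy regression. This is an illustration only, not the paper's proof or setup: the linear data-generating process, the variable names, and the sample size are assumptions. A masked object's state is driven by one "neighbor" object; a least-squares predictor that sees the neighbor reaches near-zero error, while one that sees only a non-interacting "bystander" cannot.

```python
import random


def fit_scalar(xs, ys):
    """Closed-form least-squares slope for y ~ w * x (no intercept)."""
    return sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)


def mse(xs, ys, w):
    """Mean squared error of the prediction w * x against y."""
    return sum((y - w * x) ** 2 for x, y in zip(xs, ys)) / len(xs)


rng = random.Random(0)
n = 2000
neighbor = [rng.gauss(0.0, 1.0) for _ in range(n)]   # inside the influence neighborhood
bystander = [rng.gauss(0.0, 1.0) for _ in range(n)]  # never interacts with the target
noise = [rng.gauss(0.0, 0.1) for _ in range(n)]
# Toy assumption: the masked object's state depends on the neighbor only.
target = [2.0 * u + e for u, e in zip(neighbor, noise)]

mse_neighbor = mse(neighbor, target, fit_scalar(neighbor, target))
mse_bystander = mse(bystander, target, fit_scalar(bystander, target))
print(f"MSE using the neighbor:  {mse_neighbor:.3f}")
print(f"MSE using the bystander: {mse_bystander:.3f}")
```

The large gap between the two errors is the toy version of the claim: no predictor that ignores the neighbor can be MSE-optimal.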
Conclusion and Limitations
• Contributions
  – First integration of JEPA + object-centric world modeling
  – Object masking = latent intervention → causal inductive bias in the loss
  – +21-point counterfactual gain, 8.6× faster planning, no reconstruction loss needed
• Limitations (see figure below)
  – Encoder quality caps performance; no causal validation; only simple benchmarks