
wadeKeith/DeepThinkVLA



🔥 DeepThinkVLA 🔥

Enhancing Reasoning Capability of Vision-Language-Action Models


DeepThinkVLA: Enhancing Reasoning Capability of Vision-Language-Action Models


📝 TODO

  • LIBERO benchmark
  • RoboTwin benchmark
  • Real-world hardware experiments

🧠 Overview

DeepThinkVLA rethinks Vision-Language-Action (VLA) policies with explicit deliberation. Starting from the public pi0-FAST checkpoint, we refactor the policy into a 2.9B-parameter hybrid decoder that writes a reasoning trace before emitting action chunks. The accompanying paper combines embodied Chain-of-Thought (CoT) supervised fine-tuning with outcome-driven reinforcement learning, yielding a 97.0% average success rate (SR) across the four LIBERO suites (Object 99.0, Spatial 96.6, Goal 96.4, Long 96.2). The hybrid architecture alone lifts success by 15.5 percentage points over a naive autoregressive CoT variant, and RL refinement supplies the final +2.0-point gain on LIBERO-Long.

✨ Highlights

  • Hybrid attention decoder cleanly separates autoregressive reasoning from parallel action generation, closing the latency gap while keeping control precise.
  • Two-stage CoT data engine distills key frames with a cloud LVLM and scales to full trajectories via a fine-tuned local VLM.
  • Outcome-based RL with grouped credit assignment aligns the full think-act sequence and stabilizes updates with KL regularization to the SFT policy.
  • Mask CoT inference (DeepThinkVLA) preserves accuracy (96.5% average SR) at 0.175x the latency of autoregressive pi0-FAST, whereas random CoT sharply degrades performance (85.1%).

πŸ—οΈ Architecture

Hybrid attention architecture

DeepThinkVLA inserts a <think> segment between observations and actions. Reasoning tokens are generated autoregressively, after which the decoder switches to bidirectional attention to emit action vectors in parallel. This resolves the modality conflict that limits single-decoder baselines and enables efficient rollouts for downstream reinforcement learning.
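The attention pattern described above can be illustrated with a small mask builder. This is a minimal sketch, not the repository's implementation: it assumes a token layout of [prefix | think | action], assumes prefix tokens (observation and instruction) attend bidirectionally among themselves as in prefix-LM decoders, and the function name `build_hybrid_mask` is hypothetical.

```python
from typing import List


def build_hybrid_mask(n_prefix: int, n_think: int, n_action: int) -> List[List[bool]]:
    """Return mask[i][j] == True when query token i may attend to key token j.

    Assumed layout: [prefix | think | action]; a hypothetical illustration.
    """
    n = n_prefix + n_think + n_action
    think_start = n_prefix
    action_start = n_prefix + n_think
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i < think_start:
                # Prefix tokens attend bidirectionally within the prefix (assumption).
                mask[i][j] = j < think_start
            elif i < action_start:
                # Reasoning tokens are causal: the full prefix plus earlier think tokens.
                mask[i][j] = j <= i
            else:
                # Action tokens see the prefix and the whole reasoning trace, and
                # attend bidirectionally to every action token, so they can be
                # emitted in parallel.
                mask[i][j] = True
    return mask


if __name__ == "__main__":
    for row in build_hybrid_mask(n_prefix=2, n_think=3, n_action=2):
        print("".join("1" if allowed else "." for allowed in row))
```

Because no reasoning token can see an action token, the think segment can be decoded token by token, after which all action positions are filled in a single bidirectional pass.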

📦 Embodied CoT Dataset

Two-stage CoT curation

A scalable annotation pipeline supplies paired reasoning/action traces:

  • Stage 1 isolates key frames via gripper-state heuristics, queries a cloud LVLM for high-quality CoT, and performs targeted human review.
  • Stage 2 fine-tunes a local VLM on those exemplars and auto-labels the remaining frames, applying schema and temporal checks to keep trajectories coherent.
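A gripper-state key-frame heuristic like the one in Stage 1 might look as follows. This is a hypothetical sketch, not the repository's selection rule: it assumes one scalar gripper value per frame where larger values mean "closed", marks a frame as key whenever the binarized state flips, and always keeps the first and last frames. `find_key_frames` and `threshold` are illustrative names.

```python
from typing import List, Sequence


def find_key_frames(gripper: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Return indices of candidate key frames for one trajectory (hypothetical heuristic)."""
    if not gripper:
        return []
    # Binarize: True means the gripper is treated as closed (assumed convention).
    closed = [g > threshold for g in gripper]
    keys = [0]  # the first frame anchors the trajectory
    for t in range(1, len(closed)):
        if closed[t] != closed[t - 1]:  # grasp or release event
            keys.append(t)
    if keys[-1] != len(closed) - 1:
        keys.append(len(closed) - 1)  # the last frame anchors the outcome
    return keys


if __name__ == "__main__":
    print(find_key_frames([0, 0, 1, 1, 1, 0, 0]))  # grasp at 2, release at 5
```

Only these frames would be sent to the cloud LVLM in such a scheme; the frames between them are left to the fine-tuned local VLM in Stage 2.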

🔄 Training Pipeline

Two-stage training with RL alignment

Training proceeds in two stages:

  • SFT cold start: token-level cross-entropy teaches the hybrid decoder to produce well-formed CoT and aligned actions under causal/bidirectional masks.
  • Outcome-driven RL: Group Relative Policy Optimization (GRPO) standardizes sparse rewards within task-conditioned groups of rollouts, while a KL penalty to the SFT policy prevents drift. The RL stage adds +2.0 SR points on LIBERO-Long and strengthens the causal link between thought and action.
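The group-standardization step of GRPO can be sketched in a few lines. This is a minimal illustration, not the VERL trainer's code: it assumes rollouts are grouped by a task identifier, uses the population standard deviation (implementations differ on this), and omits the KL term, which is added separately in the policy loss. `group_advantages` is a hypothetical name.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def group_advantages(
    task_ids: Sequence[str], rewards: Sequence[float], eps: float = 1e-6
) -> List[float]:
    """Standardize each rollout's reward against rollouts of the same task."""
    groups: Dict[str, List[float]] = defaultdict(list)
    for tid, r in zip(task_ids, rewards):
        groups[tid].append(r)
    stats: Dict[str, Tuple[float, float]] = {}
    for tid, rs in groups.items():
        mean = sum(rs) / len(rs)
        std = math.sqrt(sum((r - mean) ** 2 for r in rs) / len(rs))
        stats[tid] = (mean, std)
    # Advantage = (reward - group mean) / (group std + eps).
    return [(r - stats[tid][0]) / (stats[tid][1] + eps) for tid, r in zip(task_ids, rewards)]


if __name__ == "__main__":
    ids = ["pick_bowl"] * 4 + ["open_drawer"] * 2
    print(group_advantages(ids, [1, 0, 1, 0, 1, 1]))
```

With sparse 0/1 success rewards, a group in which every rollout succeeds (or every rollout fails) receives zero advantage, so only tasks with mixed outcomes contribute gradient signal.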

📊 Performance

Effect of RL and architecture choices

  • DeepThinkVLA reaches a 97.0% average success rate across LIBERO, outperforming autoregressive, diffusion, and parallel-decoding baselines under the single-model protocol.
  • RL over SFT lifts LIBERO-Long from 94.2% to 96.2% without extra demonstrations and exhibits recovery behavior on long-horizon tasks.
  • The hybrid decoder outperforms the naive autoregressive CoT variant by 15.5 points while keeping latency manageable; Mask CoT inference retains accuracy at 0.175x the latency of pi0-FAST.
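As a quick arithmetic check, the headline average can be recomputed from the per-suite numbers given in the Overview, and the RL gain from the SFT and SFT+RL figures for LIBERO-Long. The sketch below uses only those reported values.

```python
# Per-suite LIBERO success rates (%) for DeepThinkVLA (SFT+RL), as reported above.
suite_sr = {"Object": 99.0, "Spatial": 96.6, "Goal": 96.4, "Long": 96.2}

average = sum(suite_sr.values()) / len(suite_sr)
rl_gain_long = 96.2 - 94.2  # LIBERO-Long: SFT+RL minus SFT-only

print(f"average SR = {average:.2f}%")  # prints 97.05
print(f"RL gain on LIBERO-Long = {rl_gain_long:.1f} points")  # prints 2.0
```

The unrounded mean is 97.05%, consistent with the 97.0% average quoted in the text.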

🎬 Qualitative Behavior

Reasoning-enabled recovery

Deliberate reasoning enables self-correction: when the robot drops an object, CoT-aware decoding identifies the mistake and guides a recovery action, whereas the reactive baseline stalls.

🛠️ Setup

Tested on Linux/WSL with NVIDIA GPUs (CUDA 12.x) and Python >= 3.10. Full SFT typically requires >= 8x80GB GPUs; RL runs assume a multi-node setup similar to scripts/run_deepthinkvla_rl.sh.

conda create -n deepthinkvla python=3.10 -y
conda activate deepthinkvla
pip install -r requirements.txt
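After installation, a short sanity check can confirm the interpreter version and GPU visibility. This is a hypothetical helper, not part of the repository; it assumes PyTorch is pulled in by requirements.txt and degrades gracefully if it is not installed.

```python
import importlib.util
import sys


def check_environment() -> dict:
    """Report Python >= 3.10 compliance and, if PyTorch is installed, CUDA visibility."""
    report = {
        "python_ok": sys.version_info >= (3, 10),
        "torch_installed": False,
        "cuda_available": False,
        "gpu_count": 0,
    }
    if importlib.util.find_spec("torch") is not None:
        import torch  # imported lazily so the check runs without PyTorch

        report["torch_installed"] = True
        report["cuda_available"] = torch.cuda.is_available()
        report["gpu_count"] = torch.cuda.device_count()
    return report


if __name__ == "__main__":
    print(check_environment())
```

Full SFT expects a `gpu_count` of at least 8 on 80 GB devices.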

If installation fails on egl_probe, install cmake==3.31.6, download the patched egl_probe source archive, install it, and retry:

pip install cmake==3.31.6
wget https://github.com/mhandb/egl_probe/archive/fix_windows_build.zip
pip install fix_windows_build.zip
pip install -r requirements.txt

Configure optional logging backends (Weights & Biases, SwanLab) before launching experiments.

💾 Data & Checkpoints

  1. LIBERO CoT demonstrations (paper Sec. 3.2):
    bash data/download_libero_cot.sh data/datasets/yinchenghust/libero_cot yinchenghust/libero_cot
  2. LIBERO simulation dataset:
    huggingface-cli download --repo-type dataset --resume-download yifengzhu-hf/LIBERO-datasets --local-dir ./src/libero/datasets/
  3. Base model weights:
    huggingface-cli download --repo-type model \
        --resume-download yinchenghust/deepthinkvla_base \
        --local-dir yinchenghust/deepthinkvla_base/
  4. Released SFT checkpoints:
    huggingface-cli download --repo-type model \
        --resume-download yinchenghust/deepthinkvla_libero_cot_sft \
        --local-dir yinchenghust/deepthinkvla_libero_cot_sft/
  5. Released SFT+RL checkpoints:
    huggingface-cli download --repo-type model \
        --resume-download yinchenghust/deepthinkvla_libero_cot_rl \
        --local-dir yinchenghust/deepthinkvla_libero_cot_rl/

Authenticate with huggingface-cli login if assets are private.
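The `huggingface-cli download` commands above can also be scripted from Python with `huggingface_hub.snapshot_download`. This is a sketch that mirrors the repo IDs and target directories listed above; the LIBERO CoT demonstrations are fetched by `data/download_libero_cot.sh` and are not covered here.

```python
from typing import List, Tuple

# (repo_id, repo_type, local_dir) triples mirroring the CLI commands above.
ASSETS: List[Tuple[str, str, str]] = [
    ("yifengzhu-hf/LIBERO-datasets", "dataset", "./src/libero/datasets/"),
    ("yinchenghust/deepthinkvla_base", "model", "yinchenghust/deepthinkvla_base/"),
    ("yinchenghust/deepthinkvla_libero_cot_sft", "model", "yinchenghust/deepthinkvla_libero_cot_sft/"),
    ("yinchenghust/deepthinkvla_libero_cot_rl", "model", "yinchenghust/deepthinkvla_libero_cot_rl/"),
]


def download_all(assets: List[Tuple[str, str, str]] = ASSETS) -> None:
    # Imported lazily so the asset list can be inspected without huggingface_hub.
    from huggingface_hub import snapshot_download

    for repo_id, repo_type, local_dir in assets:
        # Download a full snapshot of the repository into local_dir.
        snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir)


if __name__ == "__main__":
    download_all()
```

Credentials stored by `huggingface-cli login` are picked up automatically by `snapshot_download`.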

🧪 Experiments

All scripts assume the repository root as the working directory and extend PYTHONPATH to src/.

Supervised fine-tuning (Table 1)

bash scripts/finetune.sh

This expands to:

deepspeed src/train.py \
  --deepspeed ./src/configs/zero2.json \
  --base_model_path <hf_base_model_id_or_local_path> \
  --repo_id <hf_dataset_repo>/libero_cot \
  --output_dir ./checkpoints/sft/deepthinkvla/libero_cot \
  --per_device_train_batch_size 8 \
  --gradient_accumulation_steps 2 \
  --num_images_in_input 2 \
  --report_to none

Key flags: set --num_images_in_input 1 for the single-camera variant; adjust --bits, --lora_enable, and --vision_lora to configure quantization and LoRA adapters; and match training schedules with --max_steps, --save_steps, and --save_total_limit.
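When changing the GPU count, it helps to keep the effective batch size fixed. Under ZeRO-2 data parallelism, each optimizer step sees per-device batch × gradient-accumulation steps × number of GPUs samples. The sketch below assumes 8 GPUs, based on the ">= 8x80GB" guidance in Setup; `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Samples contributing to one optimizer step under data parallelism."""
    return per_device * grad_accum * num_gpus


if __name__ == "__main__":
    # --per_device_train_batch_size 8, --gradient_accumulation_steps 2, 8 GPUs (assumed).
    print(effective_batch_size(per_device=8, grad_accum=2, num_gpus=8))  # prints 128
```

On 4 GPUs, for example, doubling --gradient_accumulation_steps to 4 would preserve the same 128-sample step.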

Evaluation

bash scripts/eval.sh \
  --pretrained_checkpoint yinchenghust/deepthinkvla_libero_cot_sft

Add arguments such as --task_suite_name libero_10 to sweep specific task sets.
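To sweep all four LIBERO suites in one go, the evaluation script can be invoked once per suite. This sketch assumes `scripts/eval.sh` forwards `--task_suite_name` as described above and uses the standard LIBERO suite identifiers (`libero_10` is LIBERO-Long); `eval_commands` is a hypothetical helper.

```python
import shlex
import subprocess
from typing import List, Sequence

# Standard LIBERO suite identifiers; libero_10 corresponds to LIBERO-Long.
SUITES = ("libero_spatial", "libero_object", "libero_goal", "libero_10")


def eval_commands(checkpoint: str, suites: Sequence[str] = SUITES) -> List[List[str]]:
    """Build one scripts/eval.sh invocation per task suite."""
    return [
        ["bash", "scripts/eval.sh", "--pretrained_checkpoint", checkpoint,
         "--task_suite_name", suite]
        for suite in suites
    ]


if __name__ == "__main__":
    # Run from the repository root, as the scripts expect.
    for cmd in eval_commands("yinchenghust/deepthinkvla_libero_cot_sft"):
        print("running:", shlex.join(cmd))
        subprocess.run(cmd, check=True)
```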

RL refinement (Table 3)

bash scripts/run_deepthinkvla_rl.sh

Configure LIBERO_CONFIG_PATH, SFT_MODEL_PATH, and hardware settings (NUM_GPUS, NUM_NODES). The trainer (python -m verl.trainer.main_ppo) implements GRPO with sparse success rewards, format regularization, and KL penalties to remain close to the SFT policy.
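The reward combination described above can be pictured as a sparse success term minus a format penalty. This is a hypothetical sketch, not the trainer's reward function: it assumes the reasoning must appear as exactly one non-empty `<think>...</think>` block, and the penalty value of 0.5 is illustrative. The KL penalty is applied in the policy loss rather than in the reward, so it is not shown.

```python
import re

# Hypothetical format check; the tags and penalty used by the real trainer may differ.
THINK_RE = re.compile(r"<think>(.+?)</think>", re.DOTALL)


def rollout_reward(success: bool, generated_text: str, format_penalty: float = 0.5) -> float:
    """Sparse task reward (1.0 on success, else 0.0) minus a penalty for malformed CoT."""
    reward = 1.0 if success else 0.0
    matches = THINK_RE.findall(generated_text)
    well_formed = len(matches) == 1 and matches[0].strip() != ""
    return reward if well_formed else reward - format_penalty


if __name__ == "__main__":
    print(rollout_reward(True, "<think>grasp the bowl, then place it on the plate</think>"))
    print(rollout_reward(True, "no reasoning segment"))
```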

bash scripts/eval.sh \
  --pretrained_checkpoint yinchenghust/deepthinkvla_libero_cot_rl

Ablations

  • Mask CoT: swap get_vla_action for get_vla_action_mask_cot in src/experiments/run_libero_eval.py to drop reasoning tokens before decoding actions.
  • Random CoT: overwrite cot_text in get_vla_action with sampled tokens to test sensitivity to reasoning quality.
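Both ablations amount to simple transformations of the reasoning segment. The sketch below operates on token lists and is not the code in `run_libero_eval.py`: `mask_cot` assumes the `<think>`/`</think>` delimiters are kept while their content is dropped, and `random_cot` samples tokens uniformly from a supplied vocabulary. Both function names are hypothetical.

```python
import random
from typing import List, Optional, Sequence


def mask_cot(tokens: Sequence[str], start: str = "<think>", end: str = "</think>") -> List[str]:
    """Mask CoT (sketch): drop tokens inside the think segment, keeping the delimiters."""
    out: List[str] = []
    inside = False
    for tok in tokens:
        if tok == start:
            inside = True
            out.append(tok)
        elif tok == end:
            inside = False
            out.append(tok)
        elif not inside:
            out.append(tok)
    return out


def random_cot(vocab: Sequence[str], length: int, seed: Optional[int] = None) -> List[str]:
    """Random CoT (sketch): sample `length` tokens uniformly from `vocab`."""
    rng = random.Random(seed)
    return [rng.choice(vocab) for _ in range(length)]


if __name__ == "__main__":
    seq = ["<obs>", "<think>", "bowl", "left", "</think>", "<act>"]
    print(mask_cot(seq))
    print(random_cot(["bowl", "plate", "left", "right"], length=4, seed=0))
```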

Measure inference latency via python -m experiments.run_libero_eval to reproduce the 0.175x runtime reported for Mask CoT.
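A latency ratio such as the 0.175x figure compares median per-step wall-clock time of two inference paths. The helper below is a generic timing sketch, not the repository's measurement code; `candidate` and `baseline` stand for any zero-argument callables that run one inference step. For GPU models, the callable should end with `torch.cuda.synchronize()` so asynchronous kernels are included in the measurement.

```python
import statistics
import time
from typing import Callable


def median_latency(fn: Callable[[], object], warmup: int = 3, repeats: int = 20) -> float:
    """Median wall-clock seconds per call, measured after a few warm-up calls."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)


def latency_ratio(candidate: Callable[[], object], baseline: Callable[[], object],
                  warmup: int = 3, repeats: int = 20) -> float:
    """Candidate latency as a fraction of baseline latency (e.g. 0.175 means 0.175x)."""
    return (median_latency(candidate, warmup, repeats)
            / median_latency(baseline, warmup, repeats))


if __name__ == "__main__":
    fast = lambda: time.sleep(0.002)  # stand-in for a Mask CoT step
    slow = lambda: time.sleep(0.02)   # stand-in for an autoregressive step
    print(f"{latency_ratio(fast, slow, warmup=1, repeats=5):.3f}x")
```

Using the median rather than the mean keeps occasional scheduler or I/O stalls from skewing the ratio.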

📁 Repository Structure

DeepThinkVLA/
├── LICENSE
├── README.md
├── requirements.txt
├── data/                  # Data helpers and CoT acquisition scripts
├── figs/                  # README figures (Fig. 1-5)
├── scripts/               # Launchers for SFT, eval, RL, and alignment
├── src/
│   ├── configs/           # Hyperparameter dataclasses and DeepSpeed configs
│   ├── dt_datasets/       # Dataset wrappers, tokenizers, normalization
│   ├── experiments/       # Evaluation utilities and LIBERO runners
│   ├── lerobot/           # Third-party LeRobot components
│   ├── libero/            # LIBERO simulator assets
│   ├── sft/               # Model, trainer, and hybrid attention utilities
│   ├── tools/             # Maintenance utilities
│   ├── train.py           # SFT entrypoint
│   └── verl/              # VERL PPO stack for RL refinement
└── checkpoints/           # (Generated) model checkpoints


🙏 Acknowledgements

DeepThinkVLA builds on open-source components from Hugging Face Transformers, PEFT, DeepSpeed, LeRobot, LIBERO, VERL, and SimpleVLA-RL, as well as the broader robotics community. We thank the maintainers of these projects.

🥰 Citation

If you find this repository helpful, please consider citing:

@article{yin2025deepthinkvla,
  title={DeepThinkVLA: Enhancing Reasoning Capability of Vision-Language-Action Models},
  author={Yin, Cheng and Lin, Yankai and Xu, Wang and Tam, Sikyuen and Zeng, Xiangrui and Liu, Zhiyuan and Yin, Zhouping},
  journal={arXiv preprint arXiv:2511.15669},
  year={2025}
}