hi @rajveer43, we generally don't add model architectures until a pretrained model using that architecture exists! Is there a relevant pretrained Titans model that people are using?
Model description
It would be highly valuable to extend the Transformers library with an implementation of the Titans model: a hybrid architecture that combines standard attention with a dedicated long-term memory module for test-time memorization, paired with fine-tuning via Group Relative Policy Optimization (GRPO). This would let LLMs handle extremely long contexts and improve chain-of-thought reasoning by adapting their memory dynamically during inference while being fine-tuned with reinforcement learning.
Motivation and Rationale:
Enhanced Long-Context Modeling:
The Titans architecture integrates a neural long-term memory module that learns to store, update, and selectively forget information based on a “surprise” metric (e.g., gradient magnitude). This mimics human long-term memory and overcomes the quadratic complexity limitation of traditional attention for long sequences.
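For reference, the memory update in the Titans paper takes roughly the following form (my paraphrase of the paper's notation, where $M_t$ is the memory, $S_t$ a momentum-style surprise state, and $\alpha_t$, $\eta_t$, $\theta_t$ are data-dependent gates):

$$
S_t = \eta_t\, S_{t-1} - \theta_t\, \nabla \ell(M_{t-1};\, x_t), \qquad M_t = (1 - \alpha_t)\, M_{t-1} + S_t
$$

Here $\ell(M; x_t) = \lVert M(k_t) - v_t \rVert^2$ is an associative-memory loss over key/value projections of the input, so the gradient term is the momentary "surprise", $\eta_t$ carries past surprise forward, and $\alpha_t$ acts as the forgetting gate.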
Adaptive Test-Time Learning:
By learning to memorize at test time, the model can update its context representation on the fly, allowing for more robust reasoning in tasks with millions of tokens.
Reinforcement Learning Fine-Tuning via GRPO:
GRPO, a variant of PPO, replaces the learned value function with group-based advantage estimates (rewards normalized across a group of sampled completions) and uses ratio clipping to stabilize policy updates. Incorporating it into the Transformers ecosystem would allow more efficient fine-tuning, reducing reliance on large supervised datasets and improving chain-of-thought outputs.
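Concretely, as described in the DeepSeekMath report, GRPO samples a group of $G$ completions per prompt and normalizes each completion's reward within the group, so no separate value network is needed; roughly:

$$
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_1, \dots, r_G)}{\operatorname{std}(r_1, \dots, r_G)}, \qquad \mathcal{L} = -\,\mathbb{E}\!\left[\min\!\big(\rho_i \hat{A}_i,\ \operatorname{clip}(\rho_i,\, 1-\epsilon,\, 1+\epsilon)\, \hat{A}_i\big)\right] + \beta\, D_{\mathrm{KL}}\!\big(\pi_\theta \,\|\, \pi_{\mathrm{ref}}\big)
$$

with $\rho_i = \pi_\theta(o_i \mid q) / \pi_{\theta_\text{old}}(o_i \mid q)$.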
Proposed Implementation:
Titans Architecture:
Introduce a new model class (e.g., TitansForCausalLM) that wraps a standard Transformer with an additional long-term memory module.
The module should accept token embeddings and update a memory vector using an MLP with momentum-based updates and an adaptive forgetting gate.
Incorporate a set of persistent memory tokens (learnable parameters) that are concatenated with the Transformer’s output before the final prediction layer.
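A minimal PyTorch sketch of what such a module could look like is below. The class name TitansMemory, the gating parameterization, and the surprise proxy are illustrative assumptions, not the paper's exact design (which updates the weights of an MLP memory at test time rather than a state vector):

```python
import torch
import torch.nn as nn


class TitansMemory(nn.Module):
    """Illustrative long-term memory block: a memory state updated with
    momentum-based "surprise" and an adaptive forgetting gate, plus
    persistent memory tokens. This is a sketch, not a reference implementation."""

    def __init__(self, hidden_size: int, num_persistent: int = 16):
        super().__init__()
        # MLP that maps a key projection of the input to a predicted value.
        self.memory_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.key_proj = nn.Linear(hidden_size, hidden_size)
        self.value_proj = nn.Linear(hidden_size, hidden_size)
        # Data-dependent gates: learning rate (theta), momentum (eta), forget (alpha).
        self.gates = nn.Linear(hidden_size, 3)
        # Persistent memory tokens: learnable, input-independent parameters.
        self.persistent_tokens = nn.Parameter(torch.randn(num_persistent, hidden_size) * 0.02)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size)
        keys, values = self.key_proj(hidden_states), self.value_proj(hidden_states)
        # "Surprise" signal: how badly the memory MLP currently reconstructs the values
        # (a proxy for the negative gradient of an L2 associative-memory loss).
        surprise = values - self.memory_mlp(keys)
        theta, eta, alpha = torch.sigmoid(self.gates(hidden_states)).chunk(3, dim=-1)

        # Momentum-style accumulation of surprise along the sequence, with forgetting.
        # (Simplified: the paper updates the memory MLP's weights themselves at test time.)
        memory = torch.zeros_like(hidden_states[:, :1])
        momentum = torch.zeros_like(memory)
        outputs = []
        for t in range(hidden_states.size(1)):
            momentum = eta[:, t:t + 1] * momentum + theta[:, t:t + 1] * surprise[:, t:t + 1]
            memory = (1.0 - alpha[:, t:t + 1]) * memory + momentum
            outputs.append(memory)
        retrieved = torch.cat(outputs, dim=1)

        # Concatenate persistent tokens with the memory-augmented hidden states;
        # the wrapping model would need to account for the extra positions.
        persistent = self.persistent_tokens.expand(hidden_states.size(0), -1, -1)
        return torch.cat([persistent, hidden_states + retrieved], dim=1)
```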
GRPO Fine-Tuning:
Create a custom trainer (e.g., subclassing TRL’s PPOTrainer) that overrides the loss computation to implement GRPO.
The loss should compute token-level log probabilities under both the old (sampling) policy and the updated policy, form the probability ratio, and clip it based on a configurable epsilon value.
Include a KL penalty term against a frozen reference policy (or a placeholder for one initially) to control how far the updated policy drifts.
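A rough sketch of that loss, assuming per-token log probabilities for a group of sampled completions have already been gathered (the function name grpo_loss and the default clip_eps/kl_coef values are illustrative, not fixed API):

```python
import torch


def grpo_loss(
    policy_logprobs: torch.Tensor,   # (group, seq) token log-probs under the current policy
    old_logprobs: torch.Tensor,      # (group, seq) token log-probs under the old/sampling policy
    ref_logprobs: torch.Tensor,      # (group, seq) token log-probs under the frozen reference model
    rewards: torch.Tensor,           # (group,) scalar reward per sampled completion
    mask: torch.Tensor,              # (group, seq) 1 for completion tokens, 0 for prompt/padding
    clip_eps: float = 0.2,
    kl_coef: float = 0.04,
) -> torch.Tensor:
    # Group-relative advantages: normalize rewards within the sampled group,
    # so no value network is needed.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    advantages = advantages.unsqueeze(-1)                      # broadcast over tokens

    # PPO-style probability ratio and clipped surrogate objective.
    ratio = torch.exp(policy_logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_term = -torch.min(unclipped, clipped)

    # Per-token KL estimate against the reference policy
    # (unbiased estimator: exp(q - p) - (q - p) - 1).
    log_diff = ref_logprobs - policy_logprobs
    kl_term = torch.exp(log_diff) - log_diff - 1.0

    per_token = policy_term + kl_coef * kl_term
    return (per_token * mask).sum() / mask.sum().clamp(min=1)
```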
Integration with TRL:
Provide example scripts demonstrating the fine-tuning loop using TRL’s APIs with the custom GRPO loss function.
Update documentation and examples to guide users on how to apply this technique to long-context reasoning tasks.
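As a starting point for those example scripts, the outer loop could look roughly like the sketch below. It calls the grpo_loss helper from the previous sketch directly rather than going through a PPOTrainer subclass; "gpt2" stands in for a real base model, and prompts and reward_fn are placeholders.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"                                      # stand-in for a real base model
policy = AutoModelForCausalLM.from_pretrained(model_name)
ref = AutoModelForCausalLM.from_pretrained(model_name).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-6)
group_size = 8

prompts = ["What is 2 + 2? Think step by step."]         # placeholder dataset


def reward_fn(text: str) -> float:
    # Placeholder reward: a real setup would score correctness, format, etc.
    return float("4" in text)


def token_logprobs(model, sequences):
    # Log-prob of each token under `model`, shifted by one position.
    logits = model(sequences).logits[:, :-1]
    return torch.log_softmax(logits, dim=-1).gather(
        -1, sequences[:, 1:].unsqueeze(-1)
    ).squeeze(-1)


for prompt in prompts:
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        # Sample a group of completions and score them.
        sequences = policy.generate(
            **inputs, do_sample=True,
            num_return_sequences=group_size, max_new_tokens=64,
        )
        old_lp = token_logprobs(policy, sequences)
        ref_lp = token_logprobs(ref, sequences)
    rewards = torch.tensor([reward_fn(tokenizer.decode(s)) for s in sequences])

    new_lp = token_logprobs(policy, sequences)
    mask = torch.ones_like(new_lp)
    mask[:, : inputs["input_ids"].shape[1] - 1] = 0      # ignore prompt tokens
    loss = grpo_loss(new_lp, old_lp, ref_lp, rewards, mask)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

A full integration would wrap this in a trainer class with batching, gradient accumulation, and generation-config handling, following TRL's existing trainer conventions.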
Environment:
Transformers version: (latest)
Python version: 3.8+
Additional libraries: TRL (for PPOTrainer extension), PyTorch
Implementing Titans with GRPO fine-tuning can potentially revolutionize how we approach long-context learning and chain-of-thought reasoning. Several recent research efforts (e.g., the Titans paper [arXiv:2501.00663] and DeepSeek's work) have demonstrated promising results with these techniques. An open implementation in Transformers would help the community experiment with these ideas and possibly drive further research in scalable and adaptive LLMs.
Open source status
Provide useful links for the implementation
https://arxiv.org/html/2501.00663v1
https://github.com/rajveer43/titan_transformer