Implement Titans Architecture with GRPO Fine-Tuning #36352

Open · 2 tasks
rajveer43 opened this issue Feb 23, 2025 · 2 comments

Comments

@rajveer43
Contributor

rajveer43 commented Feb 23, 2025

Model description

It would be highly valuable to extend the Transformers library with an implementation of the Titans model, a hybrid architecture that combines standard attention with a dedicated long-term memory module (for test-time memorization), fine-tuned with a Group Relative Policy Optimization (GRPO) method. This would let LLMs handle extremely long contexts and improve chain-of-thought reasoning by adapting their memory dynamically during inference while being fine-tuned with reinforcement learning techniques.

Motivation and Rationale:

• Enhanced Long-Context Modeling: The Titans architecture integrates a neural long-term memory module that learns to store, update, and selectively forget information based on a “surprise” metric (e.g., gradient magnitude). This mimics human long-term memory and sidesteps the quadratic complexity of standard attention on long sequences.
• Adaptive Test-Time Learning: By learning to memorize at test time, the model can update its context representation on the fly, allowing for more robust reasoning over inputs of millions of tokens.
• Reinforcement Learning Fine-Tuning via GRPO: GRPO, a variant of PPO, uses group-based advantage estimates and ratio clipping to stabilize policy updates (a minimal sketch of the objective follows below). Incorporating it into Transformers would enable more efficient fine-tuning, reducing reliance on large supervised datasets and improving chain-of-thought outputs.
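
A minimal sketch of that group-relative objective in plain PyTorch. The function name, group size, and clipping value are illustrative assumptions, not part of any existing API:

```python
import torch

def grpo_surrogate(logprobs_new, logprobs_old, rewards, clip_eps=0.2):
    """Clipped surrogate loss with group-relative advantages.

    logprobs_new / logprobs_old: (group_size,) summed log-probs of each sampled
    completion under the current and reference policies.
    rewards: (group_size,) scalar reward for each completion in the group.
    """
    # Group-relative advantage: normalize rewards within the sampled group.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # Probability ratio between the updated and reference policies.
    ratio = torch.exp(logprobs_new - logprobs_old)

    # PPO-style clipping keeps the update within a trust region.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()

# Toy example: a group of 4 sampled completions for one prompt.
loss = grpo_surrogate(
    logprobs_new=torch.tensor([-12.0, -10.5, -11.2, -9.8], requires_grad=True),
    logprobs_old=torch.tensor([-12.1, -10.7, -11.0, -10.0]),
    rewards=torch.tensor([0.0, 1.0, 0.2, 0.8]),
)
loss.backward()
```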

Proposed Implementation:

Titans Architecture:

• Introduce a new model class (e.g., TitansForCausalLM) that wraps a standard Transformer with an additional long-term memory module.
• The module should accept token embeddings and update a memory vector using an MLP with momentum-based updates and an adaptive forgetting gate.
• Incorporate a set of persistent memory tokens (learnable parameters) that are concatenated with the Transformer’s output before the final prediction layer (a rough sketch follows this list).
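
A rough sketch of what that memory module could look like, in plain PyTorch. Everything here (class name, dimensions, and the learned-gate stand-in for the paper's gradient-based surprise signal) is an illustrative assumption rather than a reference implementation:

```python
import torch
import torch.nn as nn

class TitansMemoryBlock(nn.Module):
    """Toy long-term memory: an MLP-written state with momentum and a forget gate."""

    def __init__(self, d_model: int, d_mem: int = 128, n_persistent: int = 4):
        super().__init__()
        self.write = nn.Sequential(
            nn.Linear(d_model, d_mem), nn.SiLU(), nn.Linear(d_mem, d_mem)
        )
        # The paper drives writes with a gradient-magnitude "surprise" signal;
        # a learned gate is used here purely to keep the sketch short.
        self.surprise_gate = nn.Linear(d_model, 1)
        self.forget_gate = nn.Linear(d_model, 1)
        # Persistent memory tokens: learnable, concatenated before the LM head.
        self.persistent_tokens = nn.Parameter(torch.zeros(n_persistent, d_model))
        self.register_buffer("memory", torch.zeros(d_mem))
        self.register_buffer("momentum", torch.zeros(d_mem))

    @torch.no_grad()
    def update(self, token_embeddings: torch.Tensor, beta: float = 0.9):
        """Test-time memorization over a chunk of shape (seq_len, d_model)."""
        pooled = token_embeddings.mean(dim=0)
        surprise = torch.sigmoid(self.surprise_gate(pooled))  # write strength
        forget = torch.sigmoid(self.forget_gate(pooled))      # adaptive forgetting
        self.momentum = beta * self.momentum + surprise * self.write(pooled)
        self.memory = (1.0 - forget) * self.memory + self.momentum

d_model = 64
block = TitansMemoryBlock(d_model)
block.update(torch.randn(16, d_model))  # one 16-token chunk at inference time
print(block.memory.shape, block.persistent_tokens.shape)
```

A hypothetical TitansForCausalLM wrapper would then call update() per chunk during generation and concatenate the memory state and persistent tokens with the hidden states before the prediction head.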

GRPO Fine-Tuning:

• Create a custom trainer (e.g., subclassing TRL’s PPOTrainer) that overrides the loss computation to implement GRPO.
• The loss should compute token-level log probabilities under both the reference (old) policy and the updated policy, form the probability ratio, and apply clipping with a configurable epsilon.
• Integrate a KL penalty term (a placeholder at first, a proper estimate later) to control deviations between the policies (a sketch of this loss follows this list).
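
A hedged sketch of that loss in plain PyTorch, independent of any particular TRL version. The helper names, shapes, coefficients, and the k3-style KL estimator are assumptions for illustration; which PPOTrainer methods to override differs across TRL releases:

```python
import torch
import torch.nn.functional as F

def token_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-token log-probabilities of the sampled tokens: (B, T, V) -> (B, T)."""
    return F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)

def grpo_loss(new_logits, ref_logits, labels, advantages, clip_eps=0.2, kl_coef=0.04):
    """Token-level GRPO loss: clipped ratio plus a KL penalty to the reference policy.

    advantages: (B,) group-relative advantage per completion, broadcast over tokens.
    """
    lp_new = token_logprobs(new_logits, labels)           # (B, T), keeps grad
    with torch.no_grad():
        lp_ref = token_logprobs(ref_logits, labels)       # reference policy

    ratio = torch.exp(lp_new - lp_ref)
    adv = advantages.unsqueeze(-1)                        # (B, 1), broadcast over T
    surrogate = torch.min(
        ratio * adv,
        torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv,
    )

    # Low-variance k3-style KL estimate between the updated and reference policies.
    kl = torch.exp(lp_ref - lp_new) - (lp_ref - lp_new) - 1.0
    return (-surrogate + kl_coef * kl).mean()

# Toy shapes: 4 sampled completions, 8 tokens each, vocabulary of 50.
B, T, V = 4, 8, 50
loss = grpo_loss(
    new_logits=torch.randn(B, T, V, requires_grad=True),
    ref_logits=torch.randn(B, T, V),
    labels=torch.randint(0, V, (B, T)),
    advantages=torch.tensor([0.5, -0.5, 1.0, -1.0]),
)
loss.backward()
```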

Integration with TRL:

• Provide example scripts demonstrating the fine-tuning loop using TRL’s APIs with the custom GRPO loss function.
• Update documentation and examples to guide users on applying this technique to long-context reasoning tasks.

Environment:

Transformers version: (latest)
Python version: 3.8+
Additional libraries: TRL (for PPOTrainer extension), PyTorch

Implementing Titans with GRPO fine-tuning can potentially revolutionize how we approach long-context learning and chain-of-thought reasoning. Several recent research efforts (e.g., the Titans paper [arXiv:2501.00663] and DeepSeek's work) have demonstrated promising results with these techniques. An open implementation in Transformers would help the community experiment with these ideas and possibly drive further research in scalable and adaptive LLMs.

Open source status

  • The model implementation is available
  • The model weights are available

Provide useful links for the implementation

https://arxiv.org/html/2501.00663v1
https://github.com/rajveer43/titan_transformer

@marthos1

Thank you, everyone.

@Rocketknight1
Member

hi @rajveer43, we generally don't add model architectures until a pretrained model using that architecture exists! Is there a relevant pretrained Titans model that people are using?
