ADD RWKV7 #2421
Conversation
Code Review
This PR introduces the RWKV-7 model, a powerful RNN architecture, to keras_hub. The contribution is significant and includes the backbone, tokenizer, preprocessor, an incomplete task model, and a checkpoint conversion script. The implementation follows the modular structure of keras_hub.

However, there are several critical issues that must be addressed before this PR can be merged:

- Missing Tests: The PR lacks unit tests for all new components. According to the contribution guidelines, testing is a mandatory requirement.[1]
- Incomplete CausalLM Task: The `RWKV7CausalLM` task model is a stub with TODOs, making it non-functional for generation.
- Critical Bugs: There are critical bugs in the tokenizer and preprocessor implementations that will cause runtime errors.
- Style Guide Violations: There are numerous style guide violations, including a filename typo, missing docstrings, and inconsistencies with the recommended model input structure.

I've left detailed comments on these issues. Once these are resolved, this will be a great addition to the library.
```python
def save_assets(self, dir_path):
    path = os.path.join(dir_path, VOCAB_FILENAME)
    with open(path, "wb") as file:
```
```python
def call_with_cache(
    self,
    token_ids,
    cache,
    cache_update_index,
):
    pass  # TODO

def _build_cache(self, token_ids):
    pass  # TODO

def generate_step(
    self,
    inputs,
    stop_token_ids=None,
):
    pass  # TODO

def score(
    self,
    token_ids,
    padding_mask=None,
    scoring_mode="logits",
    layer_intercept_fn=None,
    target_ids=None,
):
    pass  # TODO
```
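These methods are required for `generate()` to work at all. As a rough sketch only (not this PR's implementation), the cache plumbing for a recurrent model in keras_hub usually looks something like the following; the attribute names `rwkv_layers` and `final_norm`, and the per-layer `state` argument, are hypothetical placeholders:

```python
def call_with_cache(self, token_ids, cache, cache_update_index):
    # Sketch under assumed attribute names -- adapt to the real backbone.
    x = self.backbone.token_embedding(token_ids)
    updated_cache = []
    for i, layer in enumerate(self.backbone.rwkv_layers):
        # Unlike an attention KV cache, an RWKV "cache" holds each
        # block's recurrent state, carried between generation steps.
        x, state = layer(x, state=cache[i])
        updated_cache.append(state)
    hidden_states = x = self.backbone.final_norm(x)
    # ReversibleEmbedding can reuse the input embedding as the output head.
    logits = self.backbone.token_embedding(x, reverse=True)
    return logits, hidden_states, updated_cache
```

`generate_step` would then call `_build_cache` once on the prompt and feed tokens through `call_with_cache` one step at a time, matching the contract used by the other `CausalLM` tasks in keras_hub.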
```python
from modelscope import snapshot_download

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.models.rwkv7.rwkv7_casual_lm import RWKV7CausalLM
```

Note the module name: `rwkv7_casual_lm` misspells "causal"; the file should be named `rwkv7_causal_lm.py`. This is the filename typo called out in the summary above.
```python
token_ids, padding_mask = self.packer(
    x, sequence_length=sequence_length, add_end_value=False
)
return token_ids
```
The `generate_preprocess` method returns a single tensor, but `generate_postprocess` expects a dictionary `{"token_ids": ..., "padding_mask": ...}`. This inconsistency will cause a `TypeError` during text generation. `generate_preprocess` should return a dictionary, to match the expected input of `generate_postprocess` and for consistency with the base preprocessor class:[1]

```python
return {
    "token_ids": token_ids,
    "padding_mask": padding_mask,
}
```

Style Guide References

1. Preprocessors should handle padding, truncation, generating attention masks, and formatting the output into a dictionary of tensors that match the backbone's input signature.
```python
def tokenize(self, inputs):
    self._check_vocabulary()
    tokens = self._tokenizer.encode(inputs)

    def tokens2ids(x):
        return [self.token_to_id(t) for t in x]

    if is_string_dtype(self.dtype):
        if isinstance(inputs, str):
            return tokens2ids(tokens)
        return [tokens2ids(t) for t in tokens]
    return tokens
```
The logic for handling `dtype="string"` is incorrect: `self._tokenizer.encode()` already returns integer IDs, so passing its output through `tokens2ids` (which expects string tokens) will not work as intended. If `dtype` is `"string"`, the integer IDs should instead be converted back to string tokens. Suggested change:
```python
def tokenize(self, inputs):
    self._check_vocabulary()
    ids = self._tokenizer.encode(inputs)

    if is_string_dtype(self.dtype):

        def ids_to_tokens(id_list):
            return [self.id_to_token(i) for i in id_list]

        if isinstance(inputs, str):
            return ids_to_tokens(ids)
        return [ids_to_tokens(id_list) for id_list in ids]
    return ids
```
"keras_hub.models.RWKVTokenizer", | ||
] | ||
) | ||
class RWKVTokenizer(tokenizer.Tokenizer): |
The `RWKVTokenizer` class is missing a docstring. Please add a comprehensive docstring that explains the tokenizer, its arguments, and includes usage examples, as per the style guide.[1]

Style Guide References

1. All public classes must have Google-style docstrings with a summary, examples, and documentation for parameters.
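For illustration only, a docstring along these lines would satisfy the guide; the `vocabulary` argument name is an assumption about this class's constructor, not a confirmed detail of the PR:

```python
class RWKVTokenizer(tokenizer.Tokenizer):
    """RWKV tokenizer using the RWKV "world" vocabulary.

    This tokenizer maps raw strings to integer token IDs and back.

    Args:
        vocabulary: a path to the RWKV vocabulary file, or an equivalent
            in-memory mapping. (Hypothetical argument name; match the
            real constructor.)

    Example:

    >>> tokenizer = RWKVTokenizer(vocabulary="rwkv_vocab.txt")  # path is illustrative
    >>> tokenizer(["The quick brown fox."])
    """
```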
```python
    - 0.5
)  # soft-clamp to (-inf, -0.5)
```
```python
@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
```
The `RWKV7Backbone` class is missing a docstring. Please add a Google-style docstring explaining the model's architecture, its parameters, and include a usage example, as specified in the style guide.[1]

Style Guide References

1. All public classes, methods, and functions must have Google-style docstrings, including a concise summary, comprehensive examples, and documentation for all parameters, return values, and exceptions.
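As a sketch of the expected shape only — the argument names below (`vocabulary_size`, `num_layers`, `hidden_dim`) are placeholders, not the PR's actual constructor signature:

```python
@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
    """The RWKV-7 core network with hyperparameters.

    This backbone implements the RWKV-7 recurrent architecture, which
    replaces self-attention with a linear-cost time-mixing block, and
    maps token IDs to a sequence of final hidden states.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of RWKV blocks. (Placeholder name.)
        hidden_dim: int. The size of the hidden states. (Placeholder name.)
        dtype: the dtype policy to use for computations and weights.

    Example:

    >>> backbone = RWKV7Backbone(vocabulary_size=65536, num_layers=12, hidden_dim=768)
    """
```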
```python
@keras_hub_export("keras_hub.models.RWKV7CausalLMPreprocessor")
class RWKV7CausalLMPreprocessor(CausalLMPreprocessor):
```
The `RWKV7CausalLMPreprocessor` class is missing a docstring. Please add a Google-style docstring explaining its purpose, parameters, and include a usage example, as specified in the style guide.[1]

Style Guide References

1. All public classes must have Google-style docstrings with a summary, examples, and documentation for parameters.
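A brief sketch of what such a docstring could cover; the `(x, y, sample_weight)` behavior described here is the base `CausalLMPreprocessor` contract, assumed to apply to this subclass:

```python
@keras_hub_export("keras_hub.models.RWKV7CausalLMPreprocessor")
class RWKV7CausalLMPreprocessor(CausalLMPreprocessor):
    """RWKV-7 causal LM preprocessor.

    For training, converts raw strings into an `(x, y, sample_weight)`
    tuple, where `y` contains the token IDs shifted one position to the
    left and `sample_weight` masks padding positions.

    Args:
        tokenizer: an `RWKVTokenizer` instance.
        sequence_length: int. The length of packed sequences.
    """
```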
```python
super().__init__(
    inputs=token_id_input,
    outputs=sequence_output,
    dtype=dtype,
    **kwargs,
)
```
The backbone's `__init__` method only accepts a single `token_ids` tensor as input. For consistency with other models in keras_hub and to improve interoperability, the backbone should be modified to accept a dictionary of inputs, including `token_ids` and `padding_mask`.[1] The `padding_mask` is currently computed inside the backbone, but it is better practice to make it an explicit input.

Style Guide References

1. Use standardized names for model input arguments to ensure interoperability. For text models, this includes `token_ids` and `padding_mask`. The backbone should accept a dictionary of these inputs.
RWKV-7 is one of the strongest RNN models available today, and this PR provides a full implementation of it in keras_hub.

📚 References

🔗 Pre-trained Checkpoints (ModelScope)

Numerical-verification notebook

This is the first modern RNN architecture in keras_hub. With the resurgence of recurrent models, more pre-trained RNN backbones will likely follow; this PR therefore also serves as a reference implementation for future work.

Current progress