Returns |
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade.md
new file mode 100644
index 0000000..563b9ed
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade.md
@@ -0,0 +1,82 @@
+description: TensorFlow Ranking Premade Orbit Task Module.
+
+
+
+
+
+
+
+
+# Module: tfr.keras.premade
+
+
+
+
+
+TensorFlow Ranking Premade Orbit Task Module.
+
+## Modules
+
+[`tfrbert_task`](../../tfr/keras/premade/tfrbert_task.md) module: TF-Ranking
+BERT task.
+
+## Classes
+
+[`class TFRBertConfig`](../../tfr/keras/premade/TFRBertConfig.md): The
+tf-ranking BERT task config.
+
+[`class TFRBertDataConfig`](../../tfr/keras/premade/TFRBertDataConfig.md): Data
+config for TFR-BERT task.
+
+[`class TFRBertDataLoader`](../../tfr/keras/premade/TFRBertDataLoader.md): A
+class to load dataset for TFR-BERT task.
+
+[`class TFRBertModelBuilder`](../../tfr/keras/premade/TFRBertModelBuilder.md):
+Model builder for TFR-BERT models.
+
+[`class TFRBertModelConfig`](../../tfr/keras/premade/TFRBertModelConfig.md): A
+TFR-BERT model configuration.
+
+[`class TFRBertScorer`](../../tfr/keras/premade/TFRBertScorer.md): Univariate
+BERT-based scorer.
+
+[`class TFRBertTask`](../../tfr/keras/premade/TFRBertTask.md): Task object for
+tf-ranking BERT.
+
+## Type Aliases
+
+[`TensorDict`](../../tfr/keras/premade/TensorDict.md): The central part of
+internal API.
+
+[`TensorLike`](../../tfr/keras/model/TensorLike.md): Union of all types that can
+be converted to a `tf.Tensor` by `tf.convert_to_tensor`.
+
+
+
+
+
+Other Members |
+
+
+
+DOCUMENT_ID
+ |
+
+`'document_id'`
+ |
+
+
+QUERY_ID
+ |
+
+`'query_id'`
+ |
+
+
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertConfig.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertConfig.md
new file mode 100644
index 0000000..08b43bf
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertConfig.md
@@ -0,0 +1,433 @@
+description: The tf-ranking BERT task config.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# tfr.keras.premade.TFRBertConfig
+
+
+
+
+
+The tf-ranking BERT task config.
+
+Inherits From:
+[`RankingTaskConfig`](../../../tfr/keras/task/RankingTaskConfig.md)
+
+
+ View aliases
+
+Main aliases
+
`tfr.keras.premade.tfrbert_task.TFRBertConfig`
+
+
+
+
+tfr.keras.premade.TFRBertConfig(
+ default_params=None,
+ restrictions=None,
+ init_checkpoint='',
+ model: tfr.keras.premade.TFRBertModelConfig
= tfr.keras.premade.TFRBertConfig.model,
+ train_data=None,
+ validation_data=None,
+ loss='softmax_loss',
+ loss_reduction='none',
+ aggregated_metrics: bool = False,
+ output_preds: bool = False
+)
+
+
+
+
+
+
+
+Attributes |
+
+
+
+`default_params`
+ |
+
+Dataclass field
+ |
+
+
+`restrictions`
+ |
+
+Dataclass field
+ |
+
+
+`init_checkpoint`
+ |
+
+Dataclass field
+ |
+
+
+`model`
+ |
+
+Dataclass field
+ |
+
+
+`train_data`
+ |
+
+Dataclass field
+ |
+
+
+`validation_data`
+ |
+
+Dataclass field
+ |
+
+
+`loss`
+ |
+
+Dataclass field
+ |
+
+
+`loss_reduction`
+ |
+
+Dataclass field
+ |
+
+
+`aggregated_metrics`
+ |
+
+Dataclass field
+ |
+
+
+`output_preds`
+ |
+
+Dataclass field
+ |
+
+
+
+## Methods
+
+as_dict
+
+
+as_dict()
+
+
+Returns a dict representation of params_dict.ParamsDict.
+
+For the nested params_dict.ParamsDict, a nested dict will be returned.
+
+from_args
+
+
+@classmethod
+from_args(
+ *args, **kwargs
+)
+
+
+Builds a config from the given list of arguments.
+
+from_json
+
+
+@classmethod
+from_json(
+ file_path: str
+)
+
+
+Wrapper for `from_yaml`.
+
+from_yaml
+
+
+@classmethod
+from_yaml(
+ file_path: str
+)
+
+
+get
+
+
+get(
+ key, value=None
+)
+
+
+Accesses through built-in dictionary get method.
+
+lock
+
+
+lock()
+
+
+Makes the ParamsDict immutable.
+
+override
+
+
+override(
+ override_params, is_strict=True
+)
+
+
+Override the ParamsDict with a set of given params.
+
+
+
+
+
+Args |
+
+
+
+`override_params`
+ |
+
+a dict or a ParamsDict specifying the parameters to be
+overridden.
+ |
+
+
+`is_strict`
+ |
+
+a boolean specifying whether override is strict or not. If
+True, keys in `override_params` must be present in the ParamsDict. If
+False, keys in `override_params` can be different from what is currently
+defined in the ParamsDict. In this case, the ParamsDict will be extended
+to include the new keys.
+ |
+
+
+
+replace
+
+
+replace(
+ **kwargs
+)
+
+
+Overrides/returns a unlocked copy with the current config unchanged.
+
+validate
+
+
+validate()
+
+
+Validate the parameters consistency based on the restrictions.
+
+This method validates the internal consistency using the pre-defined list of
+restrictions. A restriction is defined as a string which specfiies a binary
+operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
+'>='}. Note that the meaning of these operators are consistent with the
+underlying Python immplementation. Users should make sure the define
+restrictions on their type make sense.
+
+For example, for a ParamsDict like the following `a: a1: 1 a2: 2 b: bb: bb1: 10
+bb2: 20 ccc: a1: 1 a3: 3` one can define two restrictions like this ['a.a1 ==
+b.ccc.a1', 'a.a2 <= b.bb.bb2']
+
+#### What it enforces are:
+
+- a.a1 = 1 == b.ccc.a1 = 1
+- a.a2 = 2 <= b.bb.bb2 = 20
+
+
+
+
+
+Raises |
+
+
+
+`KeyError`
+ |
+
+if any of the following happens
+(1) any of parameters in any of restrictions is not defined in
+ParamsDict,
+(2) any inconsistency violating the restriction is found.
+ |
+
+
+`ValueError`
+ |
+
+if the restriction defined in the string is not supported.
+ |
+
+
+
+__contains__
+
+
+__contains__(
+ key
+)
+
+
+Implements the membership test operator.
+
+__eq__
+
+
+__eq__(
+ other
+)
+
+
+
+
+
+
+Class Variables |
+
+
+
+IMMUTABLE_TYPES
+ |
+
+`(,
+ ,
+ ,
+ ,
+ )`
+ |
+
+
+RESERVED_ATTR
+ |
+
+`['_locked', '_restrictions']`
+ |
+
+
+SEQUENCE_TYPES
+ |
+
+`(, )`
+ |
+
+
+aggregated_metrics
+ |
+
+`False`
+ |
+
+
+default_params
+ |
+
+`None`
+ |
+
+
+init_checkpoint
+ |
+
+`''`
+ |
+
+
+loss
+ |
+
+`'softmax_loss'`
+ |
+
+
+loss_reduction
+ |
+
+`'none'`
+ |
+
+
+model
+ |
+
+Instance of tfr.keras.premade.TFRBertModelConfig
+ |
+
+
+output_preds
+ |
+
+`False`
+ |
+
+
+restrictions
+ |
+
+`None`
+ |
+
+
+train_data
+ |
+
+`None`
+ |
+
+
+validation_data
+ |
+
+`None`
+ |
+
+
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertDataConfig.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertDataConfig.md
new file mode 100644
index 0000000..3fd86a7
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertDataConfig.md
@@ -0,0 +1,753 @@
+description: Data config for TFR-BERT task.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# tfr.keras.premade.TFRBertDataConfig
+
+
+
+
+
+Data config for TFR-BERT task.
+
+Inherits From:
+[`RankingDataConfig`](../../../tfr/keras/task/RankingDataConfig.md)
+
+
+ View aliases
+
+Main aliases
+
`tfr.keras.premade.tfrbert_task.TFRBertDataConfig`
+
+
+
+
+tfr.keras.premade.TFRBertDataConfig(
+ default_params=None,
+ restrictions=None,
+ input_path='',
+ tfds_name='',
+ tfds_split='',
+ global_batch_size=0,
+ is_training=True,
+ drop_remainder=True,
+ shuffle_buffer_size=100,
+ cache=False,
+ cycle_length=None,
+ block_length=1,
+ deterministic=None,
+ sharding=True,
+ enable_tf_data_service=False,
+ tf_data_service_address=None,
+ tf_data_service_job_name=None,
+ tfds_data_dir='',
+ tfds_as_supervised=False,
+ tfds_skip_decoding_feature='',
+ seed=None,
+ data_format='example_list_with_context',
+ dataset_fn='tfrecord',
+ list_size=None,
+ shuffle_examples=False,
+ convert_labels_to_binary=False,
+ mask_feature_name='example_list_mask',
+ seq_length: int = 128,
+ read_query_id: bool = False,
+ read_document_id: bool = False
+)
+
+
+
+
+
+
+
+Attributes |
+
+
+
+`default_params`
+ |
+
+Dataclass field
+ |
+
+
+`restrictions`
+ |
+
+Dataclass field
+ |
+
+
+`input_path`
+ |
+
+Dataclass field
+ |
+
+
+`tfds_name`
+ |
+
+Dataclass field
+ |
+
+
+`tfds_split`
+ |
+
+Dataclass field
+ |
+
+
+`global_batch_size`
+ |
+
+Dataclass field
+ |
+
+
+`is_training`
+ |
+
+Dataclass field
+ |
+
+
+`drop_remainder`
+ |
+
+Dataclass field
+ |
+
+
+`shuffle_buffer_size`
+ |
+
+Dataclass field
+ |
+
+
+`cache`
+ |
+
+Dataclass field
+ |
+
+
+`cycle_length`
+ |
+
+Dataclass field
+ |
+
+
+`block_length`
+ |
+
+Dataclass field
+ |
+
+
+`deterministic`
+ |
+
+Dataclass field
+ |
+
+
+`sharding`
+ |
+
+Dataclass field
+ |
+
+
+`enable_tf_data_service`
+ |
+
+Dataclass field
+ |
+
+
+`tf_data_service_address`
+ |
+
+Dataclass field
+ |
+
+
+`tf_data_service_job_name`
+ |
+
+Dataclass field
+ |
+
+
+`tfds_data_dir`
+ |
+
+Dataclass field
+ |
+
+
+`tfds_as_supervised`
+ |
+
+Dataclass field
+ |
+
+
+`tfds_skip_decoding_feature`
+ |
+
+Dataclass field
+ |
+
+
+`seed`
+ |
+
+Dataclass field
+ |
+
+
+`data_format`
+ |
+
+Dataclass field
+ |
+
+
+`dataset_fn`
+ |
+
+Dataclass field
+ |
+
+
+`list_size`
+ |
+
+Dataclass field
+ |
+
+
+`shuffle_examples`
+ |
+
+Dataclass field
+ |
+
+
+`convert_labels_to_binary`
+ |
+
+Dataclass field
+ |
+
+
+`mask_feature_name`
+ |
+
+Dataclass field
+ |
+
+
+`seq_length`
+ |
+
+Dataclass field
+ |
+
+
+`read_query_id`
+ |
+
+Dataclass field
+ |
+
+
+`read_document_id`
+ |
+
+Dataclass field
+ |
+
+
+
+## Methods
+
+as_dict
+
+
+as_dict()
+
+
+Returns a dict representation of params_dict.ParamsDict.
+
+For the nested params_dict.ParamsDict, a nested dict will be returned.
+
+from_args
+
+
+@classmethod
+from_args(
+ *args, **kwargs
+)
+
+
+Builds a config from the given list of arguments.
+
+from_json
+
+
+@classmethod
+from_json(
+ file_path: str
+)
+
+
+Wrapper for `from_yaml`.
+
+from_yaml
+
+
+@classmethod
+from_yaml(
+ file_path: str
+)
+
+
+get
+
+
+get(
+ key, value=None
+)
+
+
+Accesses through built-in dictionary get method.
+
+lock
+
+
+lock()
+
+
+Makes the ParamsDict immutable.
+
+override
+
+
+override(
+ override_params, is_strict=True
+)
+
+
+Override the ParamsDict with a set of given params.
+
+
+
+
+
+Args |
+
+
+
+`override_params`
+ |
+
+a dict or a ParamsDict specifying the parameters to be
+overridden.
+ |
+
+
+`is_strict`
+ |
+
+a boolean specifying whether override is strict or not. If
+True, keys in `override_params` must be present in the ParamsDict. If
+False, keys in `override_params` can be different from what is currently
+defined in the ParamsDict. In this case, the ParamsDict will be extended
+to include the new keys.
+ |
+
+
+
+replace
+
+
+replace(
+ **kwargs
+)
+
+
+Overrides/returns a unlocked copy with the current config unchanged.
+
+validate
+
+
+validate()
+
+
+Validate the parameters consistency based on the restrictions.
+
+This method validates the internal consistency using the pre-defined list of
+restrictions. A restriction is defined as a string which specfiies a binary
+operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
+'>='}. Note that the meaning of these operators are consistent with the
+underlying Python immplementation. Users should make sure the define
+restrictions on their type make sense.
+
+For example, for a ParamsDict like the following `a: a1: 1 a2: 2 b: bb: bb1: 10
+bb2: 20 ccc: a1: 1 a3: 3` one can define two restrictions like this ['a.a1 ==
+b.ccc.a1', 'a.a2 <= b.bb.bb2']
+
+#### What it enforces are:
+
+- a.a1 = 1 == b.ccc.a1 = 1
+- a.a2 = 2 <= b.bb.bb2 = 20
+
+
+
+
+
+Raises |
+
+
+
+`KeyError`
+ |
+
+if any of the following happens
+(1) any of parameters in any of restrictions is not defined in
+ParamsDict,
+(2) any inconsistency violating the restriction is found.
+ |
+
+
+`ValueError`
+ |
+
+if the restriction defined in the string is not supported.
+ |
+
+
+
+__contains__
+
+
+__contains__(
+ key
+)
+
+
+Implements the membership test operator.
+
+__eq__
+
+
+__eq__(
+ other
+)
+
+
+
+
+
+
+Class Variables |
+
+
+
+IMMUTABLE_TYPES
+ |
+
+`(,
+ ,
+ ,
+ ,
+ )`
+ |
+
+
+RESERVED_ATTR
+ |
+
+`['_locked', '_restrictions']`
+ |
+
+
+SEQUENCE_TYPES
+ |
+
+`(, )`
+ |
+
+
+block_length
+ |
+
+`1`
+ |
+
+
+cache
+ |
+
+`False`
+ |
+
+
+convert_labels_to_binary
+ |
+
+`False`
+ |
+
+
+cycle_length
+ |
+
+`None`
+ |
+
+
+data_format
+ |
+
+`'example_list_with_context'`
+ |
+
+
+dataset_fn
+ |
+
+`'tfrecord'`
+ |
+
+
+default_params
+ |
+
+`None`
+ |
+
+
+deterministic
+ |
+
+`None`
+ |
+
+
+drop_remainder
+ |
+
+`True`
+ |
+
+
+enable_tf_data_service
+ |
+
+`False`
+ |
+
+
+global_batch_size
+ |
+
+`0`
+ |
+
+
+input_path
+ |
+
+`''`
+ |
+
+
+is_training
+ |
+
+`True`
+ |
+
+
+list_size
+ |
+
+`None`
+ |
+
+
+mask_feature_name
+ |
+
+`'example_list_mask'`
+ |
+
+
+read_document_id
+ |
+
+`False`
+ |
+
+
+read_query_id
+ |
+
+`False`
+ |
+
+
+restrictions
+ |
+
+`None`
+ |
+
+
+seed
+ |
+
+`None`
+ |
+
+
+seq_length
+ |
+
+`128`
+ |
+
+
+sharding
+ |
+
+`True`
+ |
+
+
+shuffle_buffer_size
+ |
+
+`100`
+ |
+
+
+shuffle_examples
+ |
+
+`False`
+ |
+
+
+tf_data_service_address
+ |
+
+`None`
+ |
+
+
+tf_data_service_job_name
+ |
+
+`None`
+ |
+
+
+tfds_as_supervised
+ |
+
+`False`
+ |
+
+
+tfds_data_dir
+ |
+
+`''`
+ |
+
+
+tfds_name
+ |
+
+`''`
+ |
+
+
+tfds_skip_decoding_feature
+ |
+
+`''`
+ |
+
+
+tfds_split
+ |
+
+`''`
+ |
+
+
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertDataLoader.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertDataLoader.md
new file mode 100644
index 0000000..a44c45f
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertDataLoader.md
@@ -0,0 +1,59 @@
+description: A class to load dataset for TFR-BERT task.
+
+
+
+
+
+
+
+
+# tfr.keras.premade.TFRBertDataLoader
+
+
+
+
+
+A class to load dataset for TFR-BERT task.
+
+Inherits From:
+[`RankingDataLoader`](../../../tfr/keras/task/RankingDataLoader.md)
+
+
+ View aliases
+
+Main aliases
+
`tfr.keras.premade.tfrbert_task.TFRBertDataLoader`
+
+
+
+
+tfr.keras.premade.TFRBertDataLoader(
+ params,
+ label_spec: Tuple[str, tf.io.FixedLenFeature] = None,
+ **kwargs
+)
+
+
+
+
+## Methods
+
+load
+
+View
+source
+
+
+load(
+ input_context: Optional[tf.distribute.InputContext] = None
+) -> tf.data.Dataset
+
+
+Returns a tf.dataset.Dataset.
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertModelBuilder.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertModelBuilder.md
new file mode 100644
index 0000000..b269ac6
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertModelBuilder.md
@@ -0,0 +1,173 @@
+description: Model builder for TFR-BERT models.
+
+
+
+
+
+
+
+
+
+
+
+# tfr.keras.premade.TFRBertModelBuilder
+
+
+
+
+
+Model builder for TFR-BERT models.
+
+Inherits From: [`ModelBuilder`](../../../tfr/keras/model/ModelBuilder.md),
+[`ModelBuilderWithMask`](../../../tfr/keras/model/ModelBuilderWithMask.md),
+[`AbstractModelBuilder`](../../../tfr/keras/model/AbstractModelBuilder.md)
+
+
+ View aliases
+
+Main aliases
+
`tfr.keras.premade.tfrbert_task.TFRBertModelBuilder`
+
+
+
+
+tfr.keras.premade.TFRBertModelBuilder(
+ input_creator: Callable[[], Tuple[TensorDict, TensorDict]],
+ preprocessor: Callable[[TensorDict, TensorDict, tf.Tensor], Tuple[TensorDict, TensorDict]],
+ scorer: Callable[[TensorDict, TensorDict, tf.Tensor], Union[TensorLike, TensorDict]],
+ mask_feature_name: str,
+ name: Optional[str] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+`input_creator`
+ |
+
+A callable or a class like `InputCreator` to implement
+`create_inputs`.
+ |
+
+
+`preprocessor`
+ |
+
+A callable or a class like `Preprocessor` to implement
+`preprocess`.
+ |
+
+
+`scorer`
+ |
+
+A callable or a class like `Scorer` to implement `score`.
+ |
+
+
+`mask_feature_name`
+ |
+
+name of 2D mask boolean feature.
+ |
+
+
+`name`
+ |
+
+(optional) name of the Model.
+ |
+
+
+
+## Methods
+
+build
+
+View
+source
+
+
+build() -> tf.keras.Model
+
+
+Builds a Keras Model for Ranking Pipeline.
+
+#### Example usage:
+
+```python
+model_builder = SimpleModelBuilder(
+ {},
+ {"example_feature_1": tf.io.FixedLenFeature(
+ shape=(1,), dtype=tf.float32, default_value=0.0)},
+ "list_mask", "model_builder")
+model = model_builder.build()
+```
+
+
+
+
+
+Returns |
+
+
+A `tf.keras.Model`.
+ |
+
+
+
+
+
+
+View
+source
+
+
+create_inputs() -> Tuple[tfr.keras.model.TensorDict
, tfr.keras.model.TensorDict
, tf.Tensor]
+
+
+See `ModelBuilderWithMask`.
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ context_inputs: tfr.keras.model.TensorDict
,
+ example_inputs: tfr.keras.model.TensorDict
,
+ mask: tf.Tensor
+) -> Tuple[tfr.keras.model.TensorDict
, tfr.keras.model.TensorDict
]
+
+
+See `ModelBuilderWithMask`.
+
+score
+
+View
+source
+
+
+score(
+ context_features: tfr.keras.model.TensorDict
,
+ example_features: tfr.keras.model.TensorDict
,
+ mask: tf.Tensor
+) -> Union[TensorLike, TensorDict]
+
+
+See `ModelBuilderWithMask`.
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertModelConfig.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertModelConfig.md
new file mode 100644
index 0000000..6cfc936
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertModelConfig.md
@@ -0,0 +1,334 @@
+description: A TFR-BERT model configuration.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# tfr.keras.premade.TFRBertModelConfig
+
+
+
+
+
+A TFR-BERT model configuration.
+
+
+ View aliases
+
+Main aliases
+
`tfr.keras.premade.tfrbert_task.TFRBertModelConfig`
+
+
+
+
+tfr.keras.premade.TFRBertModelConfig(
+ default_params=None,
+ restrictions=None,
+ dropout_rate: float = 0.1,
+ encoder: encoders.EncoderConfig = tfr.keras.premade.TFRBertModelConfig.encoder
+)
+
+
+
+
+
+
+
+Attributes |
+
+
+
+`default_params`
+ |
+
+Dataclass field
+ |
+
+
+`restrictions`
+ |
+
+Dataclass field
+ |
+
+
+`dropout_rate`
+ |
+
+Dataclass field
+ |
+
+
+`encoder`
+ |
+
+Dataclass field
+ |
+
+
+
+## Methods
+
+as_dict
+
+
+as_dict()
+
+
+Returns a dict representation of params_dict.ParamsDict.
+
+For the nested params_dict.ParamsDict, a nested dict will be returned.
+
+from_args
+
+
+@classmethod
+from_args(
+ *args, **kwargs
+)
+
+
+Builds a config from the given list of arguments.
+
+from_json
+
+
+@classmethod
+from_json(
+ file_path: str
+)
+
+
+Wrapper for `from_yaml`.
+
+from_yaml
+
+
+@classmethod
+from_yaml(
+ file_path: str
+)
+
+
+get
+
+
+get(
+ key, value=None
+)
+
+
+Accesses through built-in dictionary get method.
+
+lock
+
+
+lock()
+
+
+Makes the ParamsDict immutable.
+
+override
+
+
+override(
+ override_params, is_strict=True
+)
+
+
+Override the ParamsDict with a set of given params.
+
+
+
+
+
+Args |
+
+
+
+`override_params`
+ |
+
+a dict or a ParamsDict specifying the parameters to be
+overridden.
+ |
+
+
+`is_strict`
+ |
+
+a boolean specifying whether override is strict or not. If
+True, keys in `override_params` must be present in the ParamsDict. If
+False, keys in `override_params` can be different from what is currently
+defined in the ParamsDict. In this case, the ParamsDict will be extended
+to include the new keys.
+ |
+
+
+
+replace
+
+
+replace(
+ **kwargs
+)
+
+
+Overrides/returns a unlocked copy with the current config unchanged.
+
+validate
+
+
+validate()
+
+
+Validate the parameters consistency based on the restrictions.
+
+This method validates the internal consistency using the pre-defined list of
+restrictions. A restriction is defined as a string which specfiies a binary
+operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
+'>='}. Note that the meaning of these operators are consistent with the
+underlying Python immplementation. Users should make sure the define
+restrictions on their type make sense.
+
+For example, for a ParamsDict like the following `a: a1: 1 a2: 2 b: bb: bb1: 10
+bb2: 20 ccc: a1: 1 a3: 3` one can define two restrictions like this ['a.a1 ==
+b.ccc.a1', 'a.a2 <= b.bb.bb2']
+
+#### What it enforces are:
+
+- a.a1 = 1 == b.ccc.a1 = 1
+- a.a2 = 2 <= b.bb.bb2 = 20
+
+
+
+
+
+Raises |
+
+
+
+`KeyError`
+ |
+
+if any of the following happens
+(1) any of parameters in any of restrictions is not defined in
+ParamsDict,
+(2) any inconsistency violating the restriction is found.
+ |
+
+
+`ValueError`
+ |
+
+if the restriction defined in the string is not supported.
+ |
+
+
+
+__contains__
+
+
+__contains__(
+ key
+)
+
+
+Implements the membership test operator.
+
+__eq__
+
+
+__eq__(
+ other
+)
+
+
+
+
+
+
+Class Variables |
+
+
+
+IMMUTABLE_TYPES
+ |
+
+`(,
+ ,
+ ,
+ ,
+ )`
+ |
+
+
+RESERVED_ATTR
+ |
+
+`['_locked', '_restrictions']`
+ |
+
+
+SEQUENCE_TYPES
+ |
+
+`(, )`
+ |
+
+
+default_params
+ |
+
+`None`
+ |
+
+
+dropout_rate
+ |
+
+`0.1`
+ |
+
+
+encoder
+ |
+
+Instance of `tensorflow_models.official.nlp.configs.encoders.EncoderConfig`
+ |
+
+
+restrictions
+ |
+
+`None`
+ |
+
+
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertScorer.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertScorer.md
new file mode 100644
index 0000000..ae06cc3
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertScorer.md
@@ -0,0 +1,63 @@
+description: Univariate BERT-based scorer.
+
+
+
+
+
+
+
+
+# tfr.keras.premade.TFRBertScorer
+
+
+
+
+
+Univariate BERT-based scorer.
+
+Inherits From:
+[`UnivariateScorer`](../../../tfr/keras/model/UnivariateScorer.md),
+[`Scorer`](../../../tfr/keras/model/Scorer.md)
+
+
+ View aliases
+
+Main aliases
+
`tfr.keras.premade.tfrbert_task.TFRBertScorer`
+
+
+
+
+tfr.keras.premade.TFRBertScorer(
+ encoder: tf.keras.Model,
+ bert_output_dropout: float,
+ name: str = 'tfrbert',
+ **kwargs
+)
+
+
+
+
+## Methods
+
+__call__
+
+View
+source
+
+
+__call__(
+ context_features: tfr.keras.model.TensorDict
,
+ example_features: tfr.keras.model.TensorDict
,
+ mask: tf.Tensor
+) -> Union[tf.Tensor, tfr.keras.model.TensorDict
]
+
+
+See `Scorer`.
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertTask.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertTask.md
new file mode 100644
index 0000000..437cf03
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TFRBertTask.md
@@ -0,0 +1,686 @@
+description: Task object for tf-ranking BERT.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# tfr.keras.premade.TFRBertTask
+
+
+
+
+
+Task object for tf-ranking BERT.
+
+Inherits From: [`RankingTask`](../../../tfr/keras/task/RankingTask.md)
+
+
+ View aliases
+
+Main aliases
+
`tfr.keras.premade.tfrbert_task.TFRBertTask`
+
+
+
+
+tfr.keras.premade.TFRBertTask(
+ params,
+ label_spec: Tuple[str, tf.io.FixedLenFeature] = None,
+ logging_dir: Optional[str] = None,
+ name: Optional[str] = None,
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+`params`
+ |
+
+the task configuration instance, which can be any of dataclass,
+ConfigDict, namedtuple, etc.
+ |
+
+
+`logging_dir`
+ |
+
+a string pointing to where the model, summaries etc. will be
+saved. You can also write additional stuff in this directory.
+ |
+
+
+`name`
+ |
+
+the task name.
+ |
+
+
+
+
+
+
+
+Attributes |
+
+ `logging_dir` |
+
+ |
`name` | Returns the name of this module as passed
+or determined in the ctor.
+
+NOTE: This is not the same as the `self.name_scope.name` which includes parent
+module names. |
`name_scope` | Returns a
+`tf.name_scope` instance for this class. |
+`non_trainable_variables` | Sequence of non-trainable variables owned
+by this module and its submodules.
+
+Note: this method uses reflection to find variables on the current instance and
+submodules. For performance reasons you may wish to cache the result of calling
+this method if you don't expect the return value to change. |
+`submodules` | Sequence of all sub-modules.
+
+Submodules are modules which are properties of this module, or found as
+properties of modules which are properties of this module (and so on).
+
+```
+>>> a = tf.Module()
+>>> b = tf.Module()
+>>> c = tf.Module()
+>>> a.b = b
+>>> b.c = c
+>>> list(a.submodules) == [b, c]
+True
+>>> list(b.submodules) == [c]
+True
+>>> list(c.submodules) == []
+True
+```
+
+ |
`task_config` |
+
+ |
`trainable_variables` | Sequence of trainable
+variables owned by this module and its submodules.
+
+Note: this method uses reflection to find variables on the current instance and
+submodules. For performance reasons you may wish to cache the result of calling
+this method if you don't expect the return value to change. |
+`variables` | Sequence of variables owned by this module and its
+submodules.
+
+Note: this method uses reflection to find variables on the current instance
+and submodules. For performance reasons you may wish to cache the result
+of calling this method if you don't expect the return value to change.
+ |
+
+
+
+## Methods
+
+aggregate_logs
+
+View
+source
+
+
+aggregate_logs(
+ state=None, step_outputs=None
+)
+
+
+Aggregates over logs. This runs on CPU in eager mode.
+
+
+
+View
+source
+
+
+build_inputs(
+ params, input_context=None
+)
+
+
+Returns tf.data.Dataset for tf-ranking BERT task.
+
+build_losses
+
+View
+source
+
+
+build_losses(
+ labels, model_outputs, aux_losses=None
+) -> tf.Tensor
+
+
+Standard interface to compute losses.
+
+
+
+
+
+Args |
+
+
+
+`labels`
+ |
+
+optional label tensors.
+ |
+
+
+`model_outputs`
+ |
+
+a nested structure of output tensors.
+ |
+
+
+`aux_losses`
+ |
+
+auxiliary loss tensors, i.e. `losses` in keras.Model.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The total loss tensor.
+ |
+
+
+
+
+build_metrics
+
+View
+source
+
+
+build_metrics(
+ training=None
+)
+
+
+Gets streaming metrics for training/validation.
+
+build_model
+
+View
+source
+
+
+build_model()
+
+
+[Optional] Creates model architecture.
+
+
+
+
+
+Returns |
+
+
+A model instance.
+ |
+
+
+
+
+create_optimizer
+
+
+@classmethod
+create_optimizer(
+ optimizer_config: OptimizationConfig,
+ runtime_config: Optional[RuntimeConfig] = None
+)
+
+
+Creates an TF optimizer from configurations.
+
+
+
+
+
+Args |
+
+
+
+`optimizer_config`
+ |
+
+the parameters of the Optimization settings.
+ |
+
+
+`runtime_config`
+ |
+
+the parameters of the runtime.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.optimizers.Optimizer object.
+ |
+
+
+
+
+inference_step
+
+
+inference_step(
+ inputs,
+ model: tf.keras.Model
+)
+
+
+Performs the forward step.
+
+With distribution strategies, this method runs on devices.
+
+
+
+
+
+Args |
+
+
+
+`inputs`
+ |
+
+a dictionary of input tensors.
+ |
+
+
+`model`
+ |
+
+the keras.Model.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Model outputs.
+ |
+
+
+
+
+initialize
+
+View
+source
+
+
+initialize(
+ model
+)
+
+
+Load a pretrained checkpoint (if exists) and then train from iter 0.
+
+process_compiled_metrics
+
+
+process_compiled_metrics(
+ compiled_metrics, labels, model_outputs
+)
+
+
+Process and update compiled_metrics.
+
+call when using compile/fit API.
+
+
+
+
+
+Args |
+
+
+
+`compiled_metrics`
+ |
+
+the compiled metrics (model.compiled_metrics).
+ |
+
+
+`labels`
+ |
+
+a tensor or a nested structure of tensors.
+ |
+
+
+`model_outputs`
+ |
+
+a tensor or a nested structure of tensors. For example,
+output of the keras model built by self.build_model.
+ |
+
+
+
+process_metrics
+
+View
+source
+
+
+process_metrics(
+ metrics, labels, model_outputs
+)
+
+
+Process and update metrics.
+
+Called when using custom training loop API.
+
+
+
+
+
+Args |
+
+
+
+`metrics`
+ |
+
+a nested structure of metrics objects. The return of function
+self.build_metrics.
+ |
+
+
+`labels`
+ |
+
+a tensor or a nested structure of tensors.
+ |
+
+
+`model_outputs`
+ |
+
+a tensor or a nested structure of tensors. For example,
+output of the keras model built by self.build_model.
+ |
+
+
+
+reduce_aggregated_logs
+
+View
+source
+
+
+reduce_aggregated_logs(
+ aggregated_logs, global_step=None
+)
+
+
+Calculates aggregated metrics and writes predictions to csv.
+
+train_step
+
+View
+source
+
+
+train_step(
+ inputs,
+ model: tf.keras.Model,
+ optimizer: tf.keras.optimizers.Optimizer,
+ metrics
+)
+
+
+Does forward and backward.
+
+With distribution strategies, this method runs on devices.
+
+
+
+
+
+Args |
+
+
+
+`inputs`
+ |
+
+a dictionary of input tensors.
+ |
+
+
+`model`
+ |
+
+the model, forward pass definition.
+ |
+
+
+`optimizer`
+ |
+
+the optimizer for this training step.
+ |
+
+
+`metrics`
+ |
+
+a nested structure of metrics objects.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A dictionary of logs.
+ |
+
+
+
+
+validation_step
+
+View
+source
+
+
+validation_step(
+ inputs,
+ model: tf.keras.Model,
+ metrics=None
+)
+
+
+Validation step.
+
+With distribution strategies, this method runs on devices.
+
+
+
+
+
+Args |
+
+
+
+`inputs`
+ |
+
+a dictionary of input tensors.
+ |
+
+
+`model`
+ |
+
+the keras.Model.
+ |
+
+
+`metrics`
+ |
+
+a nested structure of metrics objects.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A dictionary of logs.
+ |
+
+
+
+
+with_name_scope
+
+
+@classmethod
+with_name_scope(
+ method
+)
+
+
+Decorator to automatically enter the module name scope.
+
+```
+>>> class MyModule(tf.Module):
+... @tf.Module.with_name_scope
+... def __call__(self, x):
+... if not hasattr(self, 'w'):
+... self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
+... return tf.matmul(x, self.w)
+```
+
+Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose names
+included the module name:
+
+```
+>>> mod = MyModule()
+>>> mod(tf.ones([1, 2]))
+
+>>> mod.w
+
+```
+
+
+
+
+
+Args |
+
+
+
+`method`
+ |
+
+The method to wrap.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The original method wrapped such that it enters the module's name scope.
+ |
+
+
+
+
+
+
+
+
+Class Variables |
+
+
+
+loss
+ |
+
+`'loss'`
+ |
+
+
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TensorDict.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TensorDict.md
new file mode 100644
index 0000000..d981340
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/TensorDict.md
@@ -0,0 +1,31 @@
+description: The central part of internal API.
+
+
+
+
+
+
+# tfr.keras.premade.TensorDict
+
+
+
+This symbol is a **type alias**.
+
+The central part of internal API.
+
+#### Source:
+
+
+TensorDict = [
+ str,
+ tfr.keras.model.TensorLike
+]
+
+
+
+
+This represents a generic version of type 'origin' with type arguments 'params'.
+There are two kind of these aliases: user defined and special. The special ones
+are wrappers around builtin collections and ABCs in collections.abc. These must
+have 'name' always set. If 'inst' is False, then the alias can't be
+instantiated, this is used by e.g. typing.List and typing.Dict.
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/tfrbert_task.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/tfrbert_task.md
new file mode 100644
index 0000000..14760d2
--- /dev/null
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/premade/tfrbert_task.md
@@ -0,0 +1,77 @@
+description: TF-Ranking BERT task.
+
+
+
+
+
+
+
+
+# Module: tfr.keras.premade.tfrbert_task
+
+
+
+
+
+TF-Ranking BERT task.
+
+## Classes
+
+[`class TFRBertConfig`](../../../tfr/keras/premade/TFRBertConfig.md): The
+tf-ranking BERT task config.
+
+[`class TFRBertDataConfig`](../../../tfr/keras/premade/TFRBertDataConfig.md):
+Data config for TFR-BERT task.
+
+[`class TFRBertDataLoader`](../../../tfr/keras/premade/TFRBertDataLoader.md): A
+class to load dataset for TFR-BERT task.
+
+[`class TFRBertModelBuilder`](../../../tfr/keras/premade/TFRBertModelBuilder.md):
+Model builder for TFR-BERT models.
+
+[`class TFRBertModelConfig`](../../../tfr/keras/premade/TFRBertModelConfig.md):
+A TFR-BERT model configuration.
+
+[`class TFRBertScorer`](../../../tfr/keras/premade/TFRBertScorer.md): Univariate
+BERT-based scorer.
+
+[`class TFRBertTask`](../../../tfr/keras/premade/TFRBertTask.md): Task object
+for tf-ranking BERT.
+
+## Type Aliases
+
+[`TensorDict`](../../../tfr/keras/premade/TensorDict.md): The central part of
+internal API.
+
+[`TensorLike`](../../../tfr/keras/model/TensorLike.md): Union of all types that
+can be converted to a `tf.Tensor` by `tf.convert_to_tensor`.
+
+
+
+
+
+Other Members |
+
+
+
+DOCUMENT_ID
+ |
+
+`'document_id'`
+ |
+
+
+QUERY_ID
+ |
+
+`'query_id'`
+ |
+
+
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task.md
index 8f9340d..9c6ad08 100644
--- a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task.md
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task.md
@@ -42,7 +42,6 @@ TF-Ranking task config.
internal API.
-
Other Members |
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/FeatureSpec.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/FeatureSpec.md
index ee9a72a..10960c2 100644
--- a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/FeatureSpec.md
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/FeatureSpec.md
@@ -8,7 +8,6 @@ description: The central part of internal API.
# tfr.keras.task.FeatureSpec
-
This symbol is a **type alias**.
The central part of internal API.
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingDataConfig.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingDataConfig.md
index bb0a227..cb7291f 100644
--- a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingDataConfig.md
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingDataConfig.md
@@ -95,8 +95,8 @@ Data set config.
-
+
Attributes |
@@ -365,7 +365,6 @@ Makes the ParamsDict immutable.
Override the ParamsDict with a set of given params.
-
Args |
@@ -427,7 +426,6 @@ b.ccc.a1', 'a.a2 <= b.bb.bb2']
- a.a2 = 2 <= b.bb.bb2 = 20
-
Raises |
@@ -471,7 +469,6 @@ Implements the membership test operator.
-
Class Variables |
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingTask.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingTask.md
index 8ec1b15..9e28586 100644
--- a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingTask.md
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingTask.md
@@ -50,8 +50,8 @@ Task object for TF-Ranking.
-
+
Args |
@@ -83,7 +83,6 @@ the task name.
-
Attributes |
@@ -160,7 +159,6 @@ from a validation step may be a tuple with elements from replicas, and a
concatenation of the elements is needed in such case.
-
Args |
@@ -200,7 +198,6 @@ Dataset functions define per-host datasets with the per-replica batch size. With
distributed training, this method runs on remote hosts.
-
Args |
@@ -224,7 +221,6 @@ optional distribution input pipeline context.
-
Returns |
@@ -250,7 +246,6 @@ source
Standard interface to compute losses.
-
Args |
@@ -280,7 +275,6 @@ auxiliary loss tensors, i.e. `losses` in keras.Model.
-
Returns |
@@ -317,7 +311,6 @@ source
[Optional] Creates model architecture.
-
Returns |
@@ -342,7 +335,6 @@ A model instance.
Creates an TF optimizer from configurations.
-
Args |
@@ -365,7 +357,6 @@ the parameters of the runtime.
-
Returns |
@@ -391,7 +382,6 @@ Performs the forward step.
With distribution strategies, this method runs on devices.
-
Args |
@@ -414,7 +404,6 @@ the keras.Model.
-
Returns |
@@ -442,7 +431,6 @@ called. You can use this callback function to load a pretrained checkpoint,
saved under a directory other than the model_dir.
-
Args |
@@ -470,7 +458,6 @@ Process and update compiled_metrics.
call when using compile/fit API.
-
Args |
@@ -516,7 +503,6 @@ Process and update metrics.
Called when using custom training loop API.
-
Args |
@@ -563,7 +549,6 @@ to compute the final metrics. It runs on CPU and in each eval_end() in base
trainer (see eval_end() function in official/core/base_trainer.py).
-
Args |
@@ -586,7 +571,6 @@ An optional variable of global step.
-
Returns |
@@ -617,7 +601,6 @@ Does forward and backward.
With distribution strategies, this method runs on devices.
-
Args |
@@ -654,7 +637,6 @@ a nested structure of metrics objects.
-
Returns |
@@ -684,7 +666,6 @@ Validation step.
With distribution strategies, this method runs on devices.
-
Args |
@@ -714,7 +695,6 @@ a nested structure of metrics objects.
-
Returns |
@@ -759,7 +739,6 @@ numpy=..., dtype=float32)>
```
-
Args |
@@ -775,7 +754,6 @@ The method to wrap.
-
Returns |
@@ -788,7 +766,6 @@ The original method wrapped such that it enters the module's name scope.
-
Class Variables |
diff --git a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingTaskConfig.md b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingTaskConfig.md
index 9001792..8e1cf3a 100644
--- a/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingTaskConfig.md
+++ b/tensorflow_ranking/g3doc/api_docs/python/tfr/keras/task/RankingTaskConfig.md
@@ -61,8 +61,8 @@ The TF-Ranking task config.
-
+
Attributes |
@@ -212,7 +212,6 @@ Makes the ParamsDict immutable.
Override the ParamsDict with a set of given params.
-
Args |
@@ -274,7 +273,6 @@ b.ccc.a1', 'a.a2 <= b.bb.bb2']
- a.a2 = 2 <= b.bb.bb2 = 20
-
Raises |
@@ -318,7 +316,6 @@ Implements the membership test operator.
-
Class Variables |
diff --git a/tensorflow_ranking/python/keras/BUILD b/tensorflow_ranking/python/keras/BUILD
index c9a1ebd..505b861 100644
--- a/tensorflow_ranking/python/keras/BUILD
+++ b/tensorflow_ranking/python/keras/BUILD
@@ -12,7 +12,6 @@ py_library(
name = "keras",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
deps = [
":estimator",
":feature",
@@ -27,6 +26,7 @@ py_library(
":task",
":utils",
"//tensorflow_ranking/python/keras/canned",
+ "//tensorflow_ranking/python/keras/premade",
],
)
diff --git a/tensorflow_ranking/python/keras/__init__.py b/tensorflow_ranking/python/keras/__init__.py
index 196ec08..4a91071 100644
--- a/tensorflow_ranking/python/keras/__init__.py
+++ b/tensorflow_ranking/python/keras/__init__.py
@@ -23,6 +23,7 @@
from tensorflow_ranking.python.keras import model
from tensorflow_ranking.python.keras import network
from tensorflow_ranking.python.keras import pipeline
+from tensorflow_ranking.python.keras import premade
from tensorflow_ranking.python.keras import saved_model
from tensorflow_ranking.python.keras import strategy_utils
from tensorflow_ranking.python.keras import task
diff --git a/tensorflow_ranking/python/keras/canned/BUILD b/tensorflow_ranking/python/keras/canned/BUILD
index c1e89ed..36ebb34 100644
--- a/tensorflow_ranking/python/keras/canned/BUILD
+++ b/tensorflow_ranking/python/keras/canned/BUILD
@@ -1,6 +1,10 @@
# TensorFlow Ranking Keras canned models.
-package(default_visibility = ["//visibility:public"])
+package(
+ default_visibility = [
+ "//tensorflow_ranking:__subpackages__",
+ ],
+)
licenses(["notice"])
@@ -8,7 +12,6 @@ py_library(
name = "canned",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
deps = [
":dnn",
":gam",
diff --git a/tensorflow_ranking/python/keras/premade/BUILD b/tensorflow_ranking/python/keras/premade/BUILD
index 46120a4..5a3cff2 100644
--- a/tensorflow_ranking/python/keras/premade/BUILD
+++ b/tensorflow_ranking/python/keras/premade/BUILD
@@ -1,9 +1,22 @@
"""TFR-BERT."""
-package(default_visibility = ["//visibility:public"])
+package(
+ default_visibility = [
+ "//tensorflow_ranking:__subpackages__",
+ ],
+)
licenses(["notice"])
+py_library(
+ name = "premade",
+ srcs = ["__init__.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":tfrbert_task",
+ ],
+)
+
py_library(
name = "tfrbert_task",
srcs = ["tfrbert_task.py"],
diff --git a/tensorflow_ranking/python/keras/premade/__init__.py b/tensorflow_ranking/python/keras/premade/__init__.py
new file mode 100644
index 0000000..28859c3
--- /dev/null
+++ b/tensorflow_ranking/python/keras/premade/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2021 The TensorFlow Ranking Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TensorFlow Ranking Premade Orbit Task Module."""
+
+from tensorflow_ranking.python.keras.premade.tfrbert_task import * # pylint: disable=wildcard-import,line-too-long
diff --git a/tensorflow_ranking/python/version.py b/tensorflow_ranking/python/version.py
index f1db59c..9d4bb58 100644
--- a/tensorflow_ranking/python/version.py
+++ b/tensorflow_ranking/python/version.py
@@ -17,7 +17,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '0'
_MINOR_VERSION = '4'
-_PATCH_VERSION = '1'
+_PATCH_VERSION = '2'
# When building releases, we can update this value on the release branch to
# reflect the current release candidate ('rc0', 'rc1') or, finally, the official
diff --git a/tensorflow_ranking/tools/pip_package/setup.py b/tensorflow_ranking/tools/pip_package/setup.py
index f844aa5..beda1b1 100644
--- a/tensorflow_ranking/tools/pip_package/setup.py
+++ b/tensorflow_ranking/tools/pip_package/setup.py
@@ -28,7 +28,7 @@
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '0.4.1'
+_VERSION = '0.4.2'
REQUIRED_PACKAGES = [
'absl-py >= 0.1.6', 'numpy >= 1.13.3', 'six >= 1.10.0',