# Representation Graphs
TensorRec allows you to define the algorithm that will be used to compute latent representations (also known as embeddings) of your users and items. You can define a custom representation function yourself, or you can use a pre-made representation function that comes with TensorRec in tensorrec.representation_graphs.
Representation functions are selected separately for users and for items.
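To see why the split matters: a recommendation score is the dot product of a user representation and an item representation, each produced by its own function, so the two feature spaces can differ in size and structure. A plain NumPy sketch of this scoring (the linear weight matrices here are illustrative stand-ins, not TensorRec's internals):

```python
import numpy as np

rng = np.random.default_rng(0)

n_users, n_items, n_components = 3, 4, 2
n_user_features, n_item_features = 5, 6

user_features = rng.random((n_users, n_user_features))
item_features = rng.random((n_items, n_item_features))

# Separate (hypothetical) linear representation functions for users and items
user_weights = rng.normal(size=(n_user_features, n_components))
item_weights = rng.normal(size=(n_item_features, n_components))

user_repr = user_features @ user_weights  # [n_users, n_components]
item_repr = item_features @ item_weights  # [n_items, n_components]

# Predicted scores are dot products of user and item representations
scores = user_repr @ item_repr.T          # [n_users, n_items]
print(scores.shape)
```

Because only the representations meet at the dot product, users and items can use entirely different representation functions as long as both output `n_components` dimensions.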
**`LinearRepresentationGraph`**: Calculates the representation by passing the features through a linear embedding.
**`NormalizedLinearRepresentationGraph`**: Calculates the representation by passing the features through a linear embedding. Embeddings are L2-normalized, meaning all embeddings have equal magnitude. This can be useful as a user representation in mixture-of-tastes models: it prevents one taste from having a much larger magnitude than the others and dominating the recommendations.
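As a sketch of what L2 normalization does to an embedding, in plain NumPy with a made-up weight matrix rather than TensorRec's own variables:

```python
import numpy as np

rng = np.random.default_rng(42)
features = rng.random((4, 6))      # [n_rows, n_features]
weights = rng.normal(size=(6, 3))  # hypothetical embedding weights

embedding = features @ weights

# L2-normalize each row so every representation has unit magnitude
normalized = embedding / np.linalg.norm(embedding, axis=1, keepdims=True)

print(np.linalg.norm(normalized, axis=1))  # every row now has norm 1.0
```

After normalization, dot products between representations depend only on direction (i.e. they behave like cosine similarity scaled by the other vector's magnitude), which is what keeps any single taste vector from dominating.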
**`ReLURepresentationGraph`**: Calculates the representations by passing the features through a single-layer ReLU neural network.
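A minimal NumPy sketch of a single-layer ReLU transform, assuming an illustrative weight matrix and zero biases (not TensorRec's actual variables):

```python
import numpy as np

rng = np.random.default_rng(7)
features = rng.random((5, 8))      # [n_rows, n_features]
weights = rng.normal(size=(8, 3))  # hypothetical layer weights
biases = np.zeros(3)

# Single ReLU layer: negative pre-activations are clamped to zero,
# giving a sparse, non-negative representation
representation = np.maximum(0.0, features @ weights + biases)
print(representation.shape)
```

The nonlinearity lets the representation capture interactions that a purely linear embedding cannot, at the cost of a less interpretable embedding space.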
**`AbstractKerasRepresentationGraph`**: This abstract `RepresentationGraph` allows you to use Keras layers as a representation function by overriding the `create_layers()` method. An example of this can be found in `examples/keras_example.py`.
The following example defines a custom representation function and uses it to build and fit a model:

```python
import tensorflow as tf
import tensorrec


# Define a custom representation function graph
class TanhRepresentationGraph(tensorrec.representation_graphs.AbstractRepresentationGraph):
    def connect_representation_graph(self, tf_features, n_components, n_features, node_name_ending):
        """
        This representation function embeds the user/item features by passing them through a single tanh layer.
        :param tf_features: tf.SparseTensor
            The user/item features as a SparseTensor of dimensions [n_users/items, n_features]
        :param n_components: int
            The dimensionality of the resulting representation.
        :param n_features: int
            The number of features in tf_features
        :param node_name_ending: String
            Either 'user' or 'item'
        :return:
            A tuple of (tf.Tensor, list) where the first value is the resulting representation in n_components
            dimensions and the second value is a list containing all tf.Variables which should be subject to
            regularization.
        """
        tf_tanh_weights = tf.Variable(tf.random_normal([n_features, n_components], stddev=.5),
                                      name='tanh_weights_%s' % node_name_ending)

        tf_repr = tf.nn.tanh(tf.sparse_tensor_dense_matmul(tf_features, tf_tanh_weights))

        # Return the representation layer and the variables to be regularized
        return tf_repr, [tf_tanh_weights]


# Build a model with the custom representation function
model = tensorrec.TensorRec(user_repr_graph=TanhRepresentationGraph(),
                            item_repr_graph=TanhRepresentationGraph())

# Generate some dummy data
interactions, user_features, item_features = tensorrec.util.generate_dummy_data(
    num_users=100,
    num_items=150,
    interaction_density=.05
)

# Fit the model for 5 epochs
model.fit(interactions, user_features, item_features, epochs=5, verbose=True)
```