Skip to content

Commit

Permalink
pg.evolve to count the root element as candidate for mutation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611146302
  • Loading branch information
daiyip authored and pyglove authors committed Feb 28, 2024
1 parent 489377c commit 73ef9db
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
17 changes: 10 additions & 7 deletions pyglove/core/hyper/evolvable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class MutationPoint:
mutation_type: 'MutationType'
location: object_utils.KeyPath
old_value: Any
parent: symbolic.Symbolic
parent: Optional[symbolic.Symbolic]


class Evolvable(custom.CustomHyper):
Expand Down Expand Up @@ -76,7 +76,6 @@ def _choose_mutation_point(k: object_utils.KeyPath,
p: Optional[symbolic.Symbolic]):
"""Visiting function for a symbolic node."""
def _add_point(mt: MutationType, k=k, v=v, p=p):
assert p is not None
mutation_points.append(MutationPoint(mt, k, v, p))
mutation_weights.append(self._weights(mt, k, v, p))

Expand All @@ -85,7 +84,8 @@ def _add_point(mt: MutationType, k=k, v=v, p=p):
f = p.sym_attr_field(k.key)
if f and f.metadata and 'no_mutation' in f.metadata:
return symbolic.TraverseAction.CONTINUE
_add_point(MutationType.REPLACE)

_add_point(MutationType.REPLACE)

# Special handle list traversal to add insertion and deletion.
if isinstance(v, symbolic.List):
Expand Down Expand Up @@ -148,10 +148,13 @@ def mutate(

# Mutating value.
if point.mutation_type == MutationType.REPLACE:
assert point.location, point
value.rebind({
str(point.location): self.node_transform(
point.location, point.old_value, point.parent)})
if point.location:
value.rebind({
str(point.location): self.node_transform(
point.location, point.old_value, point.parent)})
else:
value = self.node_transform(
point.location, point.old_value, point.parent)
elif point.mutation_type == MutationType.INSERT:
assert isinstance(point.parent, symbolic.List), point
assert point.old_value == object_utils.MISSING_VALUE, point
Expand Down
14 changes: 14 additions & 0 deletions pyglove/core/hyper/evolvable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ def test_replace(self):
])
]))

# Mutating at root.
v = evolve(
seed_program(), lambda k, v, p: ReLU(),
weights=lambda mt, k, v, p: 1.0 if p is None else 0.0)
self.assertEqual(
v.mutate(seed_program()),
ReLU()
)

def test_insertion(self):
v = evolve(
seed_program(), lambda k, v, p: ReLU(),
Expand Down Expand Up @@ -161,6 +170,7 @@ def test_mutation_points_and_weights(self):
# NOTE(daiyip): Conv.kernel_size is marked with 'no_mutation', thus
# it should not show here.
self.assertEqual([(p.mutation_type, p.location) for p in points], [
(MutationType.REPLACE, ''),
(MutationType.REPLACE, 'layers'),
(MutationType.INSERT, 'layers[0]'),
(MutationType.DELETE, 'layers[0]'),
Expand Down Expand Up @@ -193,6 +203,7 @@ def test_mutation_points_and_weights_with_honoring_list_size(self):
weights=lambda *x: 1.0)
points, _ = v.mutation_points_and_weights(symbolic.List([1]))
self.assertEqual([(p.mutation_type, p.location) for p in points], [
(MutationType.REPLACE, ''),
(MutationType.INSERT, '[0]'),
(MutationType.DELETE, '[0]'),
(MutationType.REPLACE, '[0]'),
Expand All @@ -204,6 +215,7 @@ def test_mutation_points_and_weights_with_honoring_list_size(self):
points, _ = v.mutation_points_and_weights(
symbolic.List([1, 2], value_spec=value_spec))
self.assertEqual([(p.mutation_type, p.location) for p in points], [
(MutationType.REPLACE, ''),
(MutationType.INSERT, '[0]'),
(MutationType.DELETE, '[0]'),
(MutationType.REPLACE, '[0]'),
Expand All @@ -215,13 +227,15 @@ def test_mutation_points_and_weights_with_honoring_list_size(self):
points, _ = v.mutation_points_and_weights(
symbolic.List([1], value_spec=value_spec))
self.assertEqual([(p.mutation_type, p.location) for p in points], [
(MutationType.REPLACE, ''),
(MutationType.INSERT, '[0]'),
(MutationType.REPLACE, '[0]'),
(MutationType.INSERT, '[1]'),
])
points, _ = v.mutation_points_and_weights(
symbolic.List([1, 2, 3], value_spec=value_spec))
self.assertEqual([(p.mutation_type, p.location) for p in points], [
(MutationType.REPLACE, ''),
(MutationType.DELETE, '[0]'),
(MutationType.REPLACE, '[0]'),
(MutationType.DELETE, '[1]'),
Expand Down

0 comments on commit 73ef9db

Please sign in to comment.