Skip to content

Commit

Permalink
Replace deprecated assertDictContainsSubset in TF-GNN for Python 3.12
Browse files Browse the repository at this point in the history
compatibility, see https://docs.python.org/3/whatsnew/3.12.html#id3

PiperOrigin-RevId: 716582633
  • Loading branch information
arnoegw authored and tensorflower-gardener committed Jan 17, 2025
1 parent b087543 commit 7509061
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tensorflow_gnn/keras/layers/graph_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def testFromConfig(self, location):
edges=dict(edge_set_name="edges"))[location]
kwargs = dict(location_kwarg, feature_name="value", name="test_readout")
config = graph_ops.Readout(**kwargs).get_config()
self.assertDictContainsSubset(kwargs, config)
self.assertEqual(kwargs, {k: v for k, v in config.items() if k in kwargs},
msg="config is expected to contain kwargs as a subset.")

readout = graph_ops.Readout.from_config(config)
self.assertEqual("value", readout.feature_name)
Expand Down Expand Up @@ -273,7 +274,7 @@ def testFromConfig(self):
kwargs = dict(node_set_name="nodes", feature_name="dense",
name="test_readout_first")
config = graph_ops.ReadoutFirstNode(**kwargs).get_config()
self.assertDictContainsSubset(kwargs, config)
self.assertEqual(config, {**config, **kwargs})

readout = graph_ops.ReadoutFirstNode.from_config(config)
self.assertEqual("dense", readout.feature_name)
Expand Down Expand Up @@ -910,7 +911,7 @@ def testFromConfig(self, tag, location, expected):
kwargs = dict(location, tag=tag, feature_name="value",
name="test_broadcast")
config = graph_ops.Broadcast(**kwargs).get_config()
self.assertDictContainsSubset(kwargs, config)
self.assertEqual(config, {**config, **kwargs})

broadcast = graph_ops.Broadcast.from_config(config)
self.assertEqual(tag, broadcast.tag)
Expand Down Expand Up @@ -1190,7 +1191,7 @@ def testFromConfig(self, tag, location, reduce_type, expected):
kwargs = dict(location, reduce_type=reduce_type, tag=tag,
feature_name="value", name="test_pool")
config = graph_ops.Pool(**kwargs).get_config()
self.assertDictContainsSubset(kwargs, config)
self.assertEqual(config, {**config, **kwargs})

pool = graph_ops.Pool.from_config(config)
self.assertEqual(tag, pool.tag)
Expand Down

0 comments on commit 7509061

Please sign in to comment.