diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..c1ab1b7
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "print.colourScheme": "nnfx"
+}
\ No newline at end of file
diff --git a/meta_dataset/data/pipeline.py b/meta_dataset/data/pipeline.py
index 1017f86..f18fdea 100644
--- a/meta_dataset/data/pipeline.py
+++ b/meta_dataset/data/pipeline.py
@@ -373,7 +373,8 @@ def make_one_source_episode_pipeline(dataset_spec,
                                      image_size=None,
                                      num_to_take=None,
                                      ignore_hierarchy_probability=0.0,
-                                     simclr_episode_fraction=0.0):
+                                     simclr_episode_fraction=0.0,
+                                     episode_sampling_seed=None):
   """Returns a pipeline emitting data from one single source as Episodes.
 
   Args:
@@ -428,7 +429,8 @@ def make_one_source_episode_pipeline(dataset_spec,
       use_dag_hierarchy=use_dag_ontology,
       use_bilevel_hierarchy=use_bilevel_ontology,
       use_all_classes=use_all_classes,
-      ignore_hierarchy_probability=ignore_hierarchy_probability)
+      ignore_hierarchy_probability=ignore_hierarchy_probability,
+      episode_sampling_seed=episode_sampling_seed)
   dataset = episode_reader.create_dataset_input_pipeline(sampler, pool=pool)
   # Episodes coming out of `dataset` contain flushed examples and are internally
   # padded with placeholder examples. `process_episode` discards flushed
@@ -463,6 +465,7 @@ def make_multisource_episode_pipeline(dataset_spec_list,
                                       image_size=None,
                                       num_to_take=None,
                                       source_sampling_seed=None,
+                                      episode_sampling_seed=None,
                                       simclr_episode_fraction=0.0):
   """Returns a pipeline emitting data from multiple sources as Episodes.
 
@@ -492,6 +495,7 @@ def make_multisource_episode_pipeline(dataset_spec_list,
       length must be the same as len(dataset_spec). If None, no restrictions
       are applied to any dataset and all data per class is used.
     source_sampling_seed: random seed for source sampling.
+    episode_sampling_seed: random seed for episode sampling.
     simclr_episode_fraction: Float, fraction of episodes that will be converted
       to SimCLR Episodes as described in the CrossTransformers paper.
 
@@ -524,7 +528,8 @@ def make_multisource_episode_pipeline(dataset_spec_list,
         episode_descr_config,
         pool=pool,
         use_dag_hierarchy=use_dag_ontology,
-        use_bilevel_hierarchy=use_bilevel_ontology)
+        use_bilevel_hierarchy=use_bilevel_ontology,
+        episode_sampling_seed=episode_sampling_seed)
     dataset = episode_reader.create_dataset_input_pipeline(sampler, pool=pool)
     # Create a dataset to zip with the above for identifying the source.
     source_id_dataset = tf.data.Dataset.from_tensors(source_id).repeat()
diff --git a/meta_dataset/data/sampling.py b/meta_dataset/data/sampling.py
index 2e96ec6..7f6883d 100644
--- a/meta_dataset/data/sampling.py
+++ b/meta_dataset/data/sampling.py
@@ -231,7 +231,8 @@ def __init__(self,
                use_dag_hierarchy=False,
                use_bilevel_hierarchy=False,
                use_all_classes=False,
-               ignore_hierarchy_probability=0.0):
+               ignore_hierarchy_probability=0.0,
+               episode_sampling_seed=None):
     """Initializes an EpisodeDescriptionSampler.episode_config.

     Args:
@@ -251,6 +252,8 @@ def __init__(self,
       ignore_hierarchy_probability: Float, if using a hierarchy, this flag
         makes the sampler ignore the hierarchy for this proportion of episodes
         and instead sample categories uniformly.
+      episode_sampling_seed: random seed for making episode description
+        sampling deterministic within individual data sources.

     Raises:
       RuntimeError: if required parameters are missing.
@@ -259,8 +262,13 @@ def __init__(self,
     # Each instance has its own RNG which is seeded from the module-level RNG,
     # which makes episode description sampling deterministic within individual
     # data sources.
-    self._rng = np.random.RandomState(
-        seed=RNG.randint(0, 2**32, size=None, dtype='uint32'))
+    if episode_sampling_seed is None:
+      self._rng = np.random.RandomState(
+          seed=RNG.randint(0, 2**32, size=None, dtype='uint32'))
+    else:
+      self._rng = np.random.RandomState(
+          seed=episode_sampling_seed)
+
     self.dataset_spec = dataset_spec
     self.split = split
     self.pool = pool
diff --git a/meta_dataset/learn/gin/best/prototypical_imagenet.gin b/meta_dataset/learn/gin/best/prototypical_imagenet.gin
index 7af0958..e575808 100644
--- a/meta_dataset/learn/gin/best/prototypical_imagenet.gin
+++ b/meta_dataset/learn/gin/best/prototypical_imagenet.gin
@@ -2,7 +2,7 @@ include 'meta_dataset/learn/gin/setups/imagenet.gin'
 include 'meta_dataset/learn/gin/learners/prototypical_config.gin'
 
 # Backbone hypers.
-include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
+include 'meta_dataset/learn/gin/best/pretrained_resnet.gin'
 
 # Data hypers.
 DataConfig.image_height = 126
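For reference, a minimal sketch (not part of the patch) of how the new keyword could be exercised. `dataset_spec`, `split`, and `episode_descr_config` are placeholders for objects built elsewhere in the codebase, and `sample_episode_description` is assumed to be the sampler's existing method for drawing one episode description.

    from meta_dataset.data import sampling

    # Two samplers constructed with the same episode_sampling_seed should draw
    # identical episode descriptions; passing None (the default) keeps the
    # original behaviour of seeding each instance from the module-level RNG.
    sampler_a = sampling.EpisodeDescriptionSampler(
        dataset_spec, split, episode_descr_config, episode_sampling_seed=1234)
    sampler_b = sampling.EpisodeDescriptionSampler(
        dataset_spec, split, episode_descr_config, episode_sampling_seed=1234)
    assert (sampler_a.sample_episode_description() ==
            sampler_b.sample_episode_description())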