Quick, Robot, Draw! turns Google’s Quick, Draw! sketches into normalized sequence data ready for in-context imitation learning with transformer-based diffusion policies, state-space models, and other sequence learners. The pipeline ingests the official .ndjson or .bin releases, preprocesses every sketch into absolute and delta trajectories with pen-state channels, assembles configurable K-shot prompt/query episodes, and stores everything in efficient backends with PyTorch-friendly loaders.
- Consistent geometry: every sketch is centered, scaled into `[-1, 1]^2`, and available as both absolute points and cumulative deltas.
- Episode-aware: episodes follow the structure `[START, prompt₁, SEP, …, RESET, START, query, STOP]` with binary control channels (pen, start, sep, reset, stop) so transformers and diffusion models can consume a single token stream.
- High-throughput I/O: supports LMDB, WebDataset shards, or HDF5 for cached sketches/episodes, plus deterministic PyTorch `Dataset` + collate utilities.
- Inspectable + verifiable: ships with scripts to visualize, profile, and sanity-check the processed cache.
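The normalization convention is easy to reproduce outside the pipeline. Below is a minimal numpy sketch of the idea; it centers on the centroid, which may differ from the repo's actual preprocessor, and the helper name is purely illustrative:

```python
import numpy as np

def normalize_and_delta(points: np.ndarray):
    """points: (N, 2) absolute x/y coordinates of one flattened sketch."""
    centered = points - points.mean(axis=0)          # center on the centroid
    scale = max(np.abs(centered).max(), 1e-8)        # largest absolute extent
    absolute = centered / scale                      # now inside [-1, 1]^2
    deltas = np.diff(absolute, axis=0, prepend=np.zeros((1, 2)))  # per-step (dx, dy)
    return absolute, deltas                          # np.cumsum(deltas, axis=0) recovers `absolute`

abs_pts, d = normalize_and_delta(np.array([[0.0, 0.0], [10.0, 5.0], [20.0, 0.0]]))
```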
- Python 3.9+
- `pip install numpy torch lmdb h5py msgpack PyYAML tqdm matplotlib`
- Install the appropriate PyTorch wheel for your platform/CUDA setup via pytorch.org.
- `gsutil` for downloading the raw QuickDraw release.
Quick, Robot, Draw! expects the official QuickDraw `.ndjson` or `.bin` files to live under a `raw_root` directory (default `raw/`). Pull whichever categories/splits you need:
```bash
# Install the Google Cloud SDK for gsutil if necessary.
mkdir -p raw
gsutil -m cp 'gs://quickdraw_dataset/full/simplified/*.ndjson' raw/
# or selectively:
gsutil cp 'gs://quickdraw_dataset/full/raw/cat.ndjson' raw/
```

You can also download individual files via the Cloud Storage browser, then place them under `raw/`.
`config/data_config.yaml` controls preprocessing and storage:

```yaml
root: "data/"            # where processed caches + manifest live
raw_root: "raw/"         # where the downloaded .ndjson/.bin files live
backend: "lmdb"          # lmdb | webdataset | hdf5
num_prompts: 5           # K-shot size
max_seq_len: 512         # episode token cap
normalize: true          # center & scale each sketch
resample:
  points: null           # optionally force per-stroke point count
augmentations:           # applied online during episode sampling
  rotation: true
  scale: true
  translation: true
storage:
  compression: "zstd"
  shards: 64
max_sketches_per_file: null  # cap sketches pulled from each raw file
families: null               # optionally whitelist specific categories
```

Adjust `raw_root`/`root` to match your filesystem. If you only want a subset, place just those files under `raw_root` or run with `--max-files` to cap the build pass.
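Because the config is plain YAML, you can also load and tweak it programmatically before a build pass. A small sketch using PyYAML; the derived filename is just an example:

```python
import yaml

# Load the shipped config, override a couple of fields, and save a variant.
with open("config/data_config.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["backend"] = "hdf5"     # try a different storage backend
cfg["num_prompts"] = 3      # smaller K-shot episodes

with open("config/data_config_small.yaml", "w") as f:
    yaml.safe_dump(cfg, f)
```

Point `--config` at the derived file when running `scripts/build_dataset.py`.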
```bash
PYTHONPATH=. python scripts/build_dataset.py \
  --config config/data_config.yaml \
  --num-workers 4 \
  --max-files 5   # optional while testing
```

This will:
- Iterate through every `.ndjson`/`.bin` under `raw_root`.
- Resample strokes (optional), flatten the strokes, and emit pen-up/down markers.
- Normalize each sketch into `[-1, 1]^2` and compute `(dx, dy)` deltas.
- Cache both representations plus metadata in the chosen backend (`data/sketches/...`).
- Write `data/DatasetManifest.json` with counts, normalization stats, and per-family split assignments (train/val/test).
- Optionally prebuild episodes (`num_prebuilt_episodes`) inside `data/episodes/`.
Use `--force` to rebuild even if a manifest already exists, and `--max-files` to process only the first N raw files on a pass.
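Once a build finishes, the manifest is plain JSON, so a quick structural dump is often enough to confirm the pass worked. A minimal sketch; the exact keys depend on your build, so this only prints what is there:

```python
import json

with open("data/DatasetManifest.json") as f:
    manifest = json.load(f)

# Print the manifest's top-level structure; key names depend on the build.
for key, value in manifest.items():
    print(f"{key}: {type(value).__name__}")
```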
Each episode contains K prompt sketches and one query sketch sampled from the same family:
```
[START, prompt₁, SEP, prompt₂, SEP, …, promptK, SEP, RESET,
 START, query, STOP]
```
Tokens are float vectors of width 7:
| Channel | Description |
|---|---|
| 0–1 | dx, dy deltas |
| 2 | pen state (1 = drawing, 0 = lift) |
| 3 | start flag |
| 4 | separator flag |
| 5 | reset flag |
| 6 | stop flag |
Per-token metadata (family IDs, prompt/query IDs, lengths) accompanies every episode so diffusion/transformer policies can condition on prompts and evaluate queries in-context.
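Given the channel layout above, a token stream can be turned back into drawable strokes by integrating the deltas and splitting wherever the pen channel drops to 0. A minimal numpy sketch of that decoding; it treats any row with a nonzero flag in channels 3 to 6 as a control marker and simply drops it, which is an assumption about how control tokens are encoded:

```python
import numpy as np

def tokens_to_strokes(tokens: np.ndarray):
    """tokens: (T, 7) rows of [dx, dy, pen, start, sep, reset, stop]."""
    drawing = tokens[tokens[:, 3:].sum(axis=1) == 0]  # drop pure control rows
    xy = np.cumsum(drawing[:, :2], axis=0)            # integrate deltas into absolute points
    strokes, current = [], []
    for point, pen in zip(xy, drawing[:, 2]):
        current.append(point)
        if pen == 0:                                  # pen lift closes the current stroke
            strokes.append(np.stack(current))
            current = []
    if current:
        strokes.append(np.stack(current))
    return strokes
```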
```python
from dataset.loader import QuickDrawEpisodes, QuickDrawEpisodesAbsolute, quickdraw_collate_fn

dataset = QuickDrawEpisodes(
    root="data/",
    split="train",
    K=5,
    backend="lmdb",
    max_seq_len=512,
    augment=True,
)

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=16, collate_fn=quickdraw_collate_fn)

for batch in loader:
    tokens = batch["tokens"]  # (B, T, 7)
    mask = batch["mask"]      # (B, T)
    # feed tokens/mask to transformer, diffusion policy, or SSM
```

Need absolute positions instead of deltas? Use the convenience subclass:

```python
dataset = QuickDrawEpisodesAbsolute(root="data/", split="train", K=5)
```

or pass `coordinate_mode="absolute"` to `QuickDrawEpisodes`.
QuickDrawEpisodes assembles episodes lazily from cached sketches, applying deterministic per-worker seeds and optional online augmentations (rotation/scale/translation/jitter). The provided collate_fn pads sequences and emits masks for turnkey batching.
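If you wrap the dataset yourself, the usual PyTorch recipe of a seeded generator plus a `worker_init_fn` gives end-to-end reproducibility across workers. This is generic PyTorch rather than a hook specific to QuickDrawEpisodes, and it reuses `dataset` and `quickdraw_collate_fn` from the snippet above:

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id: int):
    # Derive numpy/python seeds from each worker's torch seed so runs are repeatable.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

loader = DataLoader(
    dataset,                          # the QuickDrawEpisodes instance from above
    batch_size=16,
    num_workers=4,
    worker_init_fn=seed_worker,
    generator=g,
    collate_fn=quickdraw_collate_fn,
)
```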
Diffusion transformers that observe the prompts plus the first S query tokens and denoise the next H tokens can use the CollateDiffusionInContext wrapper:
```python
from dataset.loader import QuickDrawEpisodes
from dataset.diffusion import CollateDiffusionInContext
from torch.utils.data import DataLoader

episodes = QuickDrawEpisodes(root="data/", split="train", K=5)
collator = CollateDiffusionInContext(horizon=64)  # randomly samples S per episode
loader = DataLoader(episodes, batch_size=16, collate_fn=collator)

batch = next(iter(loader))
tokens = batch["tokens"]              # padded episode tokens
context_mask = batch["context_mask"]  # prompt + observed query tokens
target_mask = batch["target_mask"]    # denoised segment (length ≤ H)
```

The collator uniformly samples how many query tokens to reveal before denoising, anywhere between 0 and the largest value that still leaves H tokens for diffusion. Batch dictionaries now include `observed_query_tokens`, `context_mask`, and `target_mask` while preserving all fields from `quickdraw_collate_fn`.
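The two masks are what you use to confine the loss to the horizon chunk. As a toy illustration only, here is a single-noise-level masked MSE; it omits the noise schedule and timestep conditioning of a real diffusion objective, and the `model(noised, context_mask)` signature is an assumption rather than the API of DiTDiffusionPolicy:

```python
import torch
import torch.nn.functional as F

def masked_denoising_loss(model, batch):
    tokens = batch["tokens"]                        # (B, T, 7) padded episode tokens
    target_mask = batch["target_mask"].float()      # (B, T), 1 on the horizon chunk

    noise = torch.randn_like(tokens)
    noised = tokens + noise * target_mask.unsqueeze(-1)   # corrupt only the target span

    pred = model(noised, batch["context_mask"])     # hypothetical model signature
    per_token = F.mse_loss(pred, noise, reduction="none").sum(-1)   # (B, T)
    return (per_token * target_mask).sum() / target_mask.sum().clamp(min=1.0)
```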
For unconditional sketch generation (no prompt context), use the new DiffusionCollator, which treats the entire sketch history as context and provides only the horizon chunk as the diffusion target.
To launch end-to-end training with the DiT diffusion policy implementation in diffusion_policy/, run:
```bash
PYTHONPATH=. python diffusion_policy/train_quickdraw.py --data-root data/ --horizon 64
```
The script wraps QuickDrawEpisodes into the fixed-width tensors required by DiTDiffusionPolicy and trains using a simple AdamW loop.
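Stripped of logging and checkpointing, that loop is the standard AdamW pattern. A hedged skeleton, where `policy.loss(batch)` stands in for whatever loss interface the policy class actually exposes:

```python
import torch
from torch.utils.data import DataLoader

from dataset.loader import QuickDrawEpisodes
from dataset.diffusion import CollateDiffusionInContext

episodes = QuickDrawEpisodes(root="data/", split="train", K=5)
loader = DataLoader(episodes, batch_size=16, shuffle=True,
                    collate_fn=CollateDiffusionInContext(horizon=64))

def train(policy, loader, epochs: int = 10):
    optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4, weight_decay=1e-2)
    for _ in range(epochs):
        for batch in loader:
            loss = policy.loss(batch)   # hypothetical loss interface; adapt to the real API
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```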
| Script | Purpose |
|---|---|
| `scripts/visualize_episode.py` | Sample an episode, plot the concatenated trajectory + per-sketch panels, and save PNGs to `figures/`. |
| `scripts/verify_dataset.py` | Validate counts, check for NaNs/shape issues, and sample episodes for sanity. |
| `scripts/profile_loading.py` | Benchmark DataLoader throughput (episodes/sec, tokens/sec). |
Run the scripts with PYTHONPATH=. so they can import the package modules.
```
data/
  DatasetManifest.json   # config + stats
  sketches/              # LMDB/WebDataset/HDF5 backend cache
  episodes/              # optional prebuilt episodes (same backend)
raw/                     # your downloaded QuickDraw files (input only)
figures/                 # visualizations from visualize_episode.py
```
Switching backends only affects how the sketches/ and episodes/ directories are structured—the higher-level APIs stay identical.
The preprocessing + episode builder stack only assumes "family_id" and a list of stroke arrays. To plug in datasets like Omniglot or LASA:
- Implement a raw loader that yields `RawSketch` instances (see the sketch after this list).
- Reuse `QuickDrawPreprocessor` or subclass it for dataset-specific normalization.
- Store the processed sketches through `SketchStorage` and use `EpisodeBuilder`/`QuickDrawEpisodes` unchanged.
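As a starting point for the first step, here is a hedged sketch of a raw loader for a hypothetical folder-per-family dataset. The dict it yields mirrors the `family_id` + stroke-list contract described above; swap it for real `RawSketch` construction once you have checked that class's fields, and adapt the file layout to your source data:

```python
from pathlib import Path
import numpy as np

def iter_raw_sketches(root: str):
    """Yield RawSketch-style records: a family id plus a list of (N, 2) stroke arrays.

    The layout assumed here (one folder per family, one .npz per drawing with a
    'strokes' entry) is hypothetical.
    """
    for family_dir in sorted(Path(root).iterdir()):
        if not family_dir.is_dir():
            continue
        for npz_path in sorted(family_dir.glob("*.npz")):
            data = np.load(npz_path, allow_pickle=True)
            strokes = [np.asarray(s, dtype=np.float32) for s in data["strokes"]]
            yield {"family_id": family_dir.name, "strokes": strokes}  # or RawSketch(...)
```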
You can train both a BiLSTM (SketchRNN) baseline and a DiT diffusion policy on:
- Unconditional single-class generation: train on one category (set `families` in `config/data_config.yaml` and `K=0` to drop prompts).
- Multi-class in-context imitation learning: train on episodes with prompts + query across all families (default `K>0`).
BiLSTM (SketchRNN) baseline, unconditional / single-class generation (set `families` in the data config and `K=0`):

```bash
PYTHONPATH=. python lstm/train_imitation_learning.py \
  --config lstm/configs/imitation_learning.py \
  --config.data_root data/ \
  --config.K 1
```

DiT diffusion policy, unconditional / single-class generation (set `families` in the data config and `K=0`):

```bash
PYTHONPATH=. python diffusion_policy/train_imitation_learning.py \
  --config diffusion_policy/configs/imitation_learning.py \
  --config.data_root data/ \
  --config.K 1
```

- The Quick, Draw! dataset is © Google, released under the Creative Commons Attribution 4.0 license; review their terms before redistribution.
- The tooling in Quick, Robot, Draw! is provided under the same license as this repository (see `LICENSE`).




