diff --git a/.gitignore b/.gitignore index 1551243..c14121e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ _test _data docker* example/image/test_* +bk* +test* # Byte-compiled / optimized / DLL files __pycache__/ @@ -130,7 +132,7 @@ celerybeat.pid *.sage.py # Environments -.env +*.env .venv env/ venv/ diff --git a/README.md b/README.md index 82d485d..114cb87 100644 --- a/README.md +++ b/README.md @@ -5,14 +5,17 @@ This repo is a Python library to **generate differentially private (DP) syntheti * **Differentially Private Synthetic Data via Foundation Model APIs 1: Images** [[paper (ICLR 2024)]](https://openreview.net/forum?id=YEhQs8POIo) [[paper (arxiv)](https://arxiv.org/abs/2305.15560)] **Authors:** [[Zinan Lin](https://zinanlin.me/)], [[Sivakanth Gopi](https://www.microsoft.com/en-us/research/people/sigopi/)], [[Janardhan Kulkarni](https://www.microsoft.com/en-us/research/people/jakul/)], [[Harsha Nori](https://www.microsoft.com/en-us/research/people/hanori/)], [[Sergey Yekhanin](http://www.yekhanin.org/)] - +* **Differentially Private Synthetic Data via Foundation Model APIs 2: Text** + [[paper (ICML 2024 Spotlight)]](https://proceedings.mlr.press/v235/xie24g.html) [[paper (arxiv)](https://arxiv.org/abs/2403.01749)] [[website](https://alphapav.github.io/augpe-dpapitext)] + **Authors:** [[Chulin Xie](https://alphapav.github.io/)], [[Zinan Lin](https://zinanlin.me/)], [[Arturs Backurs](https://www.mit.edu/~backurs/)], [[Sivakanth Gopi](https://www.microsoft.com/en-us/research/people/sigopi/)], [[Da Yu](https://dayu11.github.io/)], [[Huseyin Inan](https://www.microsoft.com/en-us/research/people/huinan/)], [[Harsha Nori](https://www.microsoft.com/en-us/research/people/hanori/)], [[Haotian Jiang](https://jhtdavid96.wixsite.com/jianghaotian)], [[Huishuai Zhang](https://huishuai-git.github.io/)], [[Yin Tat Lee](https://yintat.com/)], [[Bo Li](https://aisecure.github.io/)], [[Sergey Yekhanin](http://www.yekhanin.org/)] ## Documentation Please refer to the [documentation](https://microsoft.github.io/DPSDA/) for more details, including the installation instructions, usage, and examples. ## News -* `11/21/2024`: The refactored codebase for image generation has been released. It is completely refactored to be more modular and easier to use and extend. The code originally published with the [paper](https://arxiv.org/abs/2305.15560) has been moved to the [deprecated](https://github.com/microsoft/DPSDA/tree/deprecated) branch, which is no longer maintained. +* `1/8/2025`: **Text generation** based on the paper `Differentially Private Synthetic Data via Foundation Model APIs 2: Text` has been integrated into the library! If you want to reproduce the results in the [paper](https://arxiv.org/abs/2403.01749), please refer to [our original codebase](https://github.com/AI-secure/aug-pe). +* `11/21/2024`: The refactored codebase for **image generation** based on the paper `Differentially Private Synthetic Data via Foundation Model APIs 1: Images` has been released! It is completely refactored to be more modular and easier to use and extend. The code originally published with the [paper](https://arxiv.org/abs/2305.15560) has been moved to the [deprecated](https://github.com/microsoft/DPSDA/tree/deprecated) branch in this repository, which is no longer maintained. 
## Contributing diff --git a/doc/build_autodoc.sh b/doc/build_autodoc.sh index d20cf2f..5c08694 100644 --- a/doc/build_autodoc.sh +++ b/doc/build_autodoc.sh @@ -1,2 +1,2 @@ -sphinx-apidoc -f --module-first -d 3 -o source/api ../pe +sphinx-apidoc -e -f --module-first -d 7 -o source/api ../pe ../pe/*/test* ../pe/*/bk* make clean html \ No newline at end of file diff --git a/doc/requirements.txt b/doc/requirements.txt index fe5fd45..c6ebe82 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,3 +1,3 @@ sphinx_rtd_theme==3.0.2 -private-evolution[image,dev] @ git+https://github.com/microsoft/DPSDA.git +private-evolution[image,text,dev] @ git+https://github.com/microsoft/DPSDA.git faiss-gpu \ No newline at end of file diff --git a/doc/source/api/api.rst b/doc/source/api/api.rst index fc98e25..f2f791a 100644 --- a/doc/source/api/api.rst +++ b/doc/source/api/api.rst @@ -3,7 +3,7 @@ API Reference .. toctree:: - :maxdepth: 3 + :maxdepth: 7 :caption: Contents: modules diff --git a/doc/source/api/modules.rst b/doc/source/api/modules.rst index a446077..086122a 100644 --- a/doc/source/api/modules.rst +++ b/doc/source/api/modules.rst @@ -2,6 +2,6 @@ pe == .. toctree:: - :maxdepth: 3 + :maxdepth: 7 pe diff --git a/doc/source/api/pe.api.api.rst b/doc/source/api/pe.api.api.rst new file mode 100644 index 0000000..2dac82f --- /dev/null +++ b/doc/source/api/pe.api.api.rst @@ -0,0 +1,7 @@ +pe.api.api module +================= + +.. automodule:: pe.api.api + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.api.image.improved_diffusion_api.rst b/doc/source/api/pe.api.image.improved_diffusion_api.rst new file mode 100644 index 0000000..91954c4 --- /dev/null +++ b/doc/source/api/pe.api.image.improved_diffusion_api.rst @@ -0,0 +1,7 @@ +pe.api.image.improved\_diffusion\_api module +============================================ + +.. automodule:: pe.api.image.improved_diffusion_api + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.api.image.improved_diffusion_lib.gaussian_diffusion.rst b/doc/source/api/pe.api.image.improved_diffusion_lib.gaussian_diffusion.rst new file mode 100644 index 0000000..e5a730f --- /dev/null +++ b/doc/source/api/pe.api.image.improved_diffusion_lib.gaussian_diffusion.rst @@ -0,0 +1,7 @@ +pe.api.image.improved\_diffusion\_lib.gaussian\_diffusion module +================================================================ + +.. automodule:: pe.api.image.improved_diffusion_lib.gaussian_diffusion + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.api.image.improved_diffusion_lib.rst b/doc/source/api/pe.api.image.improved_diffusion_lib.rst index 841044c..8a40ef0 100644 --- a/doc/source/api/pe.api.image.improved_diffusion_lib.rst +++ b/doc/source/api/pe.api.image.improved_diffusion_lib.rst @@ -9,18 +9,8 @@ pe.api.image.improved\_diffusion\_lib package Submodules ---------- -pe.api.image.improved\_diffusion\_lib.gaussian\_diffusion module ----------------------------------------------------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.api.image.improved_diffusion_lib.gaussian_diffusion - :members: - :undoc-members: - :show-inheritance: - -pe.api.image.improved\_diffusion\_lib.unet module -------------------------------------------------- - -.. 
automodule:: pe.api.image.improved_diffusion_lib.unet - :members: - :undoc-members: - :show-inheritance: + pe.api.image.improved_diffusion_lib.gaussian_diffusion + pe.api.image.improved_diffusion_lib.unet diff --git a/doc/source/api/pe.api.image.improved_diffusion_lib.unet.rst b/doc/source/api/pe.api.image.improved_diffusion_lib.unet.rst new file mode 100644 index 0000000..59564f6 --- /dev/null +++ b/doc/source/api/pe.api.image.improved_diffusion_lib.unet.rst @@ -0,0 +1,7 @@ +pe.api.image.improved\_diffusion\_lib.unet module +================================================= + +.. automodule:: pe.api.image.improved_diffusion_lib.unet + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.api.image.rst b/doc/source/api/pe.api.image.rst index c53f742..1eb8ca6 100644 --- a/doc/source/api/pe.api.image.rst +++ b/doc/source/api/pe.api.image.rst @@ -10,25 +10,15 @@ Subpackages ----------- .. toctree:: - :maxdepth: 3 + :maxdepth: 7 pe.api.image.improved_diffusion_lib Submodules ---------- -pe.api.image.improved\_diffusion\_api module --------------------------------------------- - -.. automodule:: pe.api.image.improved_diffusion_api - :members: - :undoc-members: - :show-inheritance: - -pe.api.image.stable\_diffusion\_api module ------------------------------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.api.image.stable_diffusion_api - :members: - :undoc-members: - :show-inheritance: + pe.api.image.improved_diffusion_api + pe.api.image.stable_diffusion_api diff --git a/doc/source/api/pe.api.image.stable_diffusion_api.rst b/doc/source/api/pe.api.image.stable_diffusion_api.rst new file mode 100644 index 0000000..0af9591 --- /dev/null +++ b/doc/source/api/pe.api.image.stable_diffusion_api.rst @@ -0,0 +1,7 @@ +pe.api.image.stable\_diffusion\_api module +========================================== + +.. automodule:: pe.api.image.stable_diffusion_api + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.api.rst b/doc/source/api/pe.api.rst index a2f2616..9febdcd 100644 --- a/doc/source/api/pe.api.rst +++ b/doc/source/api/pe.api.rst @@ -10,25 +10,16 @@ Subpackages ----------- .. toctree:: - :maxdepth: 3 + :maxdepth: 7 pe.api.image + pe.api.text Submodules ---------- -pe.api.api module ------------------ - -.. automodule:: pe.api.api - :members: - :undoc-members: - :show-inheritance: - -pe.api.util module ------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.api.util - :members: - :undoc-members: - :show-inheritance: + pe.api.api + pe.api.util diff --git a/doc/source/api/pe.api.text.llm_augpe_api.rst b/doc/source/api/pe.api.text.llm_augpe_api.rst new file mode 100644 index 0000000..7015c3a --- /dev/null +++ b/doc/source/api/pe.api.text.llm_augpe_api.rst @@ -0,0 +1,7 @@ +pe.api.text.llm\_augpe\_api module +================================== + +.. automodule:: pe.api.text.llm_augpe_api + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.api.text.rst b/doc/source/api/pe.api.text.rst new file mode 100644 index 0000000..372734b --- /dev/null +++ b/doc/source/api/pe.api.text.rst @@ -0,0 +1,15 @@ +pe.api.text package +=================== + +.. automodule:: pe.api.text + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. 
toctree:: + :maxdepth: 7 + + pe.api.text.llm_augpe_api diff --git a/doc/source/api/pe.api.util.rst b/doc/source/api/pe.api.util.rst new file mode 100644 index 0000000..10bc89c --- /dev/null +++ b/doc/source/api/pe.api.util.rst @@ -0,0 +1,7 @@ +pe.api.util module +================== + +.. automodule:: pe.api.util + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.callback.callback.rst b/doc/source/api/pe.callback.callback.rst new file mode 100644 index 0000000..843fb5c --- /dev/null +++ b/doc/source/api/pe.callback.callback.rst @@ -0,0 +1,7 @@ +pe.callback.callback module +=========================== + +.. automodule:: pe.callback.callback + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.callback.common.compute_fid.rst b/doc/source/api/pe.callback.common.compute_fid.rst new file mode 100644 index 0000000..8b11656 --- /dev/null +++ b/doc/source/api/pe.callback.common.compute_fid.rst @@ -0,0 +1,7 @@ +pe.callback.common.compute\_fid module +====================================== + +.. automodule:: pe.callback.common.compute_fid + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.callback.common.rst b/doc/source/api/pe.callback.common.rst index 9c387d0..4b2ef1c 100644 --- a/doc/source/api/pe.callback.common.rst +++ b/doc/source/api/pe.callback.common.rst @@ -9,18 +9,8 @@ pe.callback.common package Submodules ---------- -pe.callback.common.compute\_fid module --------------------------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.callback.common.compute_fid - :members: - :undoc-members: - :show-inheritance: - -pe.callback.common.save\_checkpoints module -------------------------------------------- - -.. automodule:: pe.callback.common.save_checkpoints - :members: - :undoc-members: - :show-inheritance: + pe.callback.common.compute_fid + pe.callback.common.save_checkpoints diff --git a/doc/source/api/pe.callback.common.save_checkpoints.rst b/doc/source/api/pe.callback.common.save_checkpoints.rst new file mode 100644 index 0000000..5d192a9 --- /dev/null +++ b/doc/source/api/pe.callback.common.save_checkpoints.rst @@ -0,0 +1,7 @@ +pe.callback.common.save\_checkpoints module +=========================================== + +.. automodule:: pe.callback.common.save_checkpoints + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.callback.image.rst b/doc/source/api/pe.callback.image.rst index 784912c..0492236 100644 --- a/doc/source/api/pe.callback.image.rst +++ b/doc/source/api/pe.callback.image.rst @@ -9,18 +9,8 @@ pe.callback.image package Submodules ---------- -pe.callback.image.sample\_images module ---------------------------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.callback.image.sample_images - :members: - :undoc-members: - :show-inheritance: - -pe.callback.image.save\_all\_images module ------------------------------------------- - -.. automodule:: pe.callback.image.save_all_images - :members: - :undoc-members: - :show-inheritance: + pe.callback.image.sample_images + pe.callback.image.save_all_images diff --git a/doc/source/api/pe.callback.image.sample_images.rst b/doc/source/api/pe.callback.image.sample_images.rst new file mode 100644 index 0000000..7a0f9ac --- /dev/null +++ b/doc/source/api/pe.callback.image.sample_images.rst @@ -0,0 +1,7 @@ +pe.callback.image.sample\_images module +======================================= + +.. 
automodule:: pe.callback.image.sample_images + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.callback.image.save_all_images.rst b/doc/source/api/pe.callback.image.save_all_images.rst new file mode 100644 index 0000000..6c74b1e --- /dev/null +++ b/doc/source/api/pe.callback.image.save_all_images.rst @@ -0,0 +1,7 @@ +pe.callback.image.save\_all\_images module +========================================== + +.. automodule:: pe.callback.image.save_all_images + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.callback.rst b/doc/source/api/pe.callback.rst index ae740bf..32b087a 100644 --- a/doc/source/api/pe.callback.rst +++ b/doc/source/api/pe.callback.rst @@ -10,18 +10,16 @@ Subpackages ----------- .. toctree:: - :maxdepth: 3 + :maxdepth: 7 pe.callback.common pe.callback.image + pe.callback.text Submodules ---------- -pe.callback.callback module ---------------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.callback.callback - :members: - :undoc-members: - :show-inheritance: + pe.callback.callback diff --git a/doc/source/api/pe.callback.text.rst b/doc/source/api/pe.callback.text.rst new file mode 100644 index 0000000..41e54ea --- /dev/null +++ b/doc/source/api/pe.callback.text.rst @@ -0,0 +1,15 @@ +pe.callback.text package +======================== + +.. automodule:: pe.callback.text + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 7 + + pe.callback.text.save_text_to_csv diff --git a/doc/source/api/pe.callback.text.save_text_to_csv.rst b/doc/source/api/pe.callback.text.save_text_to_csv.rst new file mode 100644 index 0000000..5f2a487 --- /dev/null +++ b/doc/source/api/pe.callback.text.save_text_to_csv.rst @@ -0,0 +1,7 @@ +pe.callback.text.save\_text\_to\_csv module +=========================================== + +.. automodule:: pe.callback.text.save_text_to_csv + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.constant.data.rst b/doc/source/api/pe.constant.data.rst new file mode 100644 index 0000000..38cd6a8 --- /dev/null +++ b/doc/source/api/pe.constant.data.rst @@ -0,0 +1,7 @@ +pe.constant.data module +======================= + +.. automodule:: pe.constant.data + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.constant.rst b/doc/source/api/pe.constant.rst index 4f56470..ae269cc 100644 --- a/doc/source/api/pe.constant.rst +++ b/doc/source/api/pe.constant.rst @@ -9,10 +9,7 @@ pe.constant package Submodules ---------- -pe.constant.data module ------------------------ +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.constant.data - :members: - :undoc-members: - :show-inheritance: + pe.constant.data diff --git a/doc/source/api/pe.data.data.rst b/doc/source/api/pe.data.data.rst new file mode 100644 index 0000000..84b524e --- /dev/null +++ b/doc/source/api/pe.data.data.rst @@ -0,0 +1,7 @@ +pe.data.data module +=================== + +.. automodule:: pe.data.data + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.data.image.camelyon17.rst b/doc/source/api/pe.data.image.camelyon17.rst new file mode 100644 index 0000000..d287ae8 --- /dev/null +++ b/doc/source/api/pe.data.image.camelyon17.rst @@ -0,0 +1,7 @@ +pe.data.image.camelyon17 module +=============================== + +.. 
automodule:: pe.data.image.camelyon17 + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.data.image.cat.rst b/doc/source/api/pe.data.image.cat.rst new file mode 100644 index 0000000..5f205b7 --- /dev/null +++ b/doc/source/api/pe.data.image.cat.rst @@ -0,0 +1,7 @@ +pe.data.image.cat module +======================== + +.. automodule:: pe.data.image.cat + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.data.image.cifar10.rst b/doc/source/api/pe.data.image.cifar10.rst new file mode 100644 index 0000000..7a7e0e2 --- /dev/null +++ b/doc/source/api/pe.data.image.cifar10.rst @@ -0,0 +1,7 @@ +pe.data.image.cifar10 module +============================ + +.. automodule:: pe.data.image.cifar10 + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.data.image.image.rst b/doc/source/api/pe.data.image.image.rst new file mode 100644 index 0000000..7adf323 --- /dev/null +++ b/doc/source/api/pe.data.image.image.rst @@ -0,0 +1,7 @@ +pe.data.image.image module +========================== + +.. automodule:: pe.data.image.image + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.data.image.rst b/doc/source/api/pe.data.image.rst index 177b224..0033fda 100644 --- a/doc/source/api/pe.data.image.rst +++ b/doc/source/api/pe.data.image.rst @@ -9,34 +9,10 @@ pe.data.image package Submodules ---------- -pe.data.image.camelyon17 module -------------------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.data.image.camelyon17 - :members: - :undoc-members: - :show-inheritance: - -pe.data.image.cat module ------------------------- - -.. automodule:: pe.data.image.cat - :members: - :undoc-members: - :show-inheritance: - -pe.data.image.cifar10 module ----------------------------- - -.. automodule:: pe.data.image.cifar10 - :members: - :undoc-members: - :show-inheritance: - -pe.data.image.image module --------------------------- - -.. automodule:: pe.data.image.image - :members: - :undoc-members: - :show-inheritance: + pe.data.image.camelyon17 + pe.data.image.cat + pe.data.image.cifar10 + pe.data.image.image diff --git a/doc/source/api/pe.data.rst b/doc/source/api/pe.data.rst index 27d092f..31e7921 100644 --- a/doc/source/api/pe.data.rst +++ b/doc/source/api/pe.data.rst @@ -10,17 +10,15 @@ Subpackages ----------- .. toctree:: - :maxdepth: 3 + :maxdepth: 7 pe.data.image + pe.data.text Submodules ---------- -pe.data.data module -------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.data.data - :members: - :undoc-members: - :show-inheritance: + pe.data.data diff --git a/doc/source/api/pe.data.text.openreview.rst b/doc/source/api/pe.data.text.openreview.rst new file mode 100644 index 0000000..6b300e1 --- /dev/null +++ b/doc/source/api/pe.data.text.openreview.rst @@ -0,0 +1,7 @@ +pe.data.text.openreview module +============================== + +.. automodule:: pe.data.text.openreview + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.data.text.pubmed.rst b/doc/source/api/pe.data.text.pubmed.rst new file mode 100644 index 0000000..aa04e45 --- /dev/null +++ b/doc/source/api/pe.data.text.pubmed.rst @@ -0,0 +1,7 @@ +pe.data.text.pubmed module +========================== + +.. 
automodule:: pe.data.text.pubmed + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.data.text.rst b/doc/source/api/pe.data.text.rst new file mode 100644 index 0000000..fae551f --- /dev/null +++ b/doc/source/api/pe.data.text.rst @@ -0,0 +1,18 @@ +pe.data.text package +==================== + +.. automodule:: pe.data.text + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 7 + + pe.data.text.openreview + pe.data.text.pubmed + pe.data.text.text_csv + pe.data.text.yelp diff --git a/doc/source/api/pe.data.text.text_csv.rst b/doc/source/api/pe.data.text.text_csv.rst new file mode 100644 index 0000000..6d22869 --- /dev/null +++ b/doc/source/api/pe.data.text.text_csv.rst @@ -0,0 +1,7 @@ +pe.data.text.text\_csv module +============================= + +.. automodule:: pe.data.text.text_csv + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.data.text.yelp.rst b/doc/source/api/pe.data.text.yelp.rst new file mode 100644 index 0000000..f6a091a --- /dev/null +++ b/doc/source/api/pe.data.text.yelp.rst @@ -0,0 +1,7 @@ +pe.data.text.yelp module +======================== + +.. automodule:: pe.data.text.yelp + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.dp.dp.rst b/doc/source/api/pe.dp.dp.rst new file mode 100644 index 0000000..5df4618 --- /dev/null +++ b/doc/source/api/pe.dp.dp.rst @@ -0,0 +1,7 @@ +pe.dp.dp module +=============== + +.. automodule:: pe.dp.dp + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.dp.gaussian.rst b/doc/source/api/pe.dp.gaussian.rst new file mode 100644 index 0000000..35ab547 --- /dev/null +++ b/doc/source/api/pe.dp.gaussian.rst @@ -0,0 +1,7 @@ +pe.dp.gaussian module +===================== + +.. automodule:: pe.dp.gaussian + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.dp.rst b/doc/source/api/pe.dp.rst index 6473b69..8086b20 100644 --- a/doc/source/api/pe.dp.rst +++ b/doc/source/api/pe.dp.rst @@ -9,18 +9,8 @@ pe.dp package Submodules ---------- -pe.dp.dp module ---------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.dp.dp - :members: - :undoc-members: - :show-inheritance: - -pe.dp.gaussian module ---------------------- - -.. automodule:: pe.dp.gaussian - :members: - :undoc-members: - :show-inheritance: + pe.dp.dp + pe.dp.gaussian diff --git a/doc/source/api/pe.embedding.embedding.rst b/doc/source/api/pe.embedding.embedding.rst new file mode 100644 index 0000000..ab96d4a --- /dev/null +++ b/doc/source/api/pe.embedding.embedding.rst @@ -0,0 +1,7 @@ +pe.embedding.embedding module +============================= + +.. automodule:: pe.embedding.embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.embedding.image.inception.rst b/doc/source/api/pe.embedding.image.inception.rst new file mode 100644 index 0000000..4424515 --- /dev/null +++ b/doc/source/api/pe.embedding.image.inception.rst @@ -0,0 +1,7 @@ +pe.embedding.image.inception module +=================================== + +.. automodule:: pe.embedding.image.inception + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.embedding.image.rst b/doc/source/api/pe.embedding.image.rst index 1777711..d470740 100644 --- a/doc/source/api/pe.embedding.image.rst +++ b/doc/source/api/pe.embedding.image.rst @@ -9,10 +9,7 @@ pe.embedding.image package Submodules ---------- -pe.embedding.image.inception module ------------------------------------ +.. 
toctree:: + :maxdepth: 7 -.. automodule:: pe.embedding.image.inception - :members: - :undoc-members: - :show-inheritance: + pe.embedding.image.inception diff --git a/doc/source/api/pe.embedding.rst b/doc/source/api/pe.embedding.rst index fbeee54..08c4f69 100644 --- a/doc/source/api/pe.embedding.rst +++ b/doc/source/api/pe.embedding.rst @@ -10,17 +10,15 @@ Subpackages ----------- .. toctree:: - :maxdepth: 3 + :maxdepth: 7 pe.embedding.image + pe.embedding.text Submodules ---------- -pe.embedding.embedding module ------------------------------ +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.embedding.embedding - :members: - :undoc-members: - :show-inheritance: + pe.embedding.embedding diff --git a/doc/source/api/pe.embedding.text.rst b/doc/source/api/pe.embedding.text.rst new file mode 100644 index 0000000..8aa2b9a --- /dev/null +++ b/doc/source/api/pe.embedding.text.rst @@ -0,0 +1,15 @@ +pe.embedding.text package +========================= + +.. automodule:: pe.embedding.text + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 7 + + pe.embedding.text.sentence_transformer diff --git a/doc/source/api/pe.embedding.text.sentence_transformer.rst b/doc/source/api/pe.embedding.text.sentence_transformer.rst new file mode 100644 index 0000000..3a4bd3e --- /dev/null +++ b/doc/source/api/pe.embedding.text.sentence_transformer.rst @@ -0,0 +1,7 @@ +pe.embedding.text.sentence\_transformer module +============================================== + +.. automodule:: pe.embedding.text.sentence_transformer + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.histogram.histogram.rst b/doc/source/api/pe.histogram.histogram.rst new file mode 100644 index 0000000..25abd20 --- /dev/null +++ b/doc/source/api/pe.histogram.histogram.rst @@ -0,0 +1,7 @@ +pe.histogram.histogram module +============================= + +.. automodule:: pe.histogram.histogram + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.histogram.nearest_neighbor_backend.auto.rst b/doc/source/api/pe.histogram.nearest_neighbor_backend.auto.rst new file mode 100644 index 0000000..6edf65c --- /dev/null +++ b/doc/source/api/pe.histogram.nearest_neighbor_backend.auto.rst @@ -0,0 +1,7 @@ +pe.histogram.nearest\_neighbor\_backend.auto module +=================================================== + +.. automodule:: pe.histogram.nearest_neighbor_backend.auto + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.histogram.nearest_neighbor_backend.faiss.rst b/doc/source/api/pe.histogram.nearest_neighbor_backend.faiss.rst new file mode 100644 index 0000000..edd84b5 --- /dev/null +++ b/doc/source/api/pe.histogram.nearest_neighbor_backend.faiss.rst @@ -0,0 +1,7 @@ +pe.histogram.nearest\_neighbor\_backend.faiss module +==================================================== + +.. automodule:: pe.histogram.nearest_neighbor_backend.faiss + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.histogram.nearest_neighbor_backend.rst b/doc/source/api/pe.histogram.nearest_neighbor_backend.rst index ac831ec..f015ddb 100644 --- a/doc/source/api/pe.histogram.nearest_neighbor_backend.rst +++ b/doc/source/api/pe.histogram.nearest_neighbor_backend.rst @@ -9,26 +9,9 @@ pe.histogram.nearest\_neighbor\_backend package Submodules ---------- -pe.histogram.nearest\_neighbor\_backend.auto module ---------------------------------------------------- +.. toctree:: + :maxdepth: 7 -.. 
automodule:: pe.histogram.nearest_neighbor_backend.auto - :members: - :undoc-members: - :show-inheritance: - -pe.histogram.nearest\_neighbor\_backend.faiss module ----------------------------------------------------- - -.. automodule:: pe.histogram.nearest_neighbor_backend.faiss - :members: - :undoc-members: - :show-inheritance: - -pe.histogram.nearest\_neighbor\_backend.sklearn module ------------------------------------------------------- - -.. automodule:: pe.histogram.nearest_neighbor_backend.sklearn - :members: - :undoc-members: - :show-inheritance: + pe.histogram.nearest_neighbor_backend.auto + pe.histogram.nearest_neighbor_backend.faiss + pe.histogram.nearest_neighbor_backend.sklearn diff --git a/doc/source/api/pe.histogram.nearest_neighbor_backend.sklearn.rst b/doc/source/api/pe.histogram.nearest_neighbor_backend.sklearn.rst new file mode 100644 index 0000000..e7e5ab7 --- /dev/null +++ b/doc/source/api/pe.histogram.nearest_neighbor_backend.sklearn.rst @@ -0,0 +1,7 @@ +pe.histogram.nearest\_neighbor\_backend.sklearn module +====================================================== + +.. automodule:: pe.histogram.nearest_neighbor_backend.sklearn + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.histogram.nearest_neighbors.rst b/doc/source/api/pe.histogram.nearest_neighbors.rst new file mode 100644 index 0000000..a5ca0ab --- /dev/null +++ b/doc/source/api/pe.histogram.nearest_neighbors.rst @@ -0,0 +1,7 @@ +pe.histogram.nearest\_neighbors module +====================================== + +.. automodule:: pe.histogram.nearest_neighbors + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.histogram.rst b/doc/source/api/pe.histogram.rst index 2316f2d..39166f9 100644 --- a/doc/source/api/pe.histogram.rst +++ b/doc/source/api/pe.histogram.rst @@ -10,25 +10,15 @@ Subpackages ----------- .. toctree:: - :maxdepth: 3 + :maxdepth: 7 pe.histogram.nearest_neighbor_backend Submodules ---------- -pe.histogram.histogram module ------------------------------ - -.. automodule:: pe.histogram.histogram - :members: - :undoc-members: - :show-inheritance: - -pe.histogram.nearest\_neighbors module --------------------------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.histogram.nearest_neighbors - :members: - :undoc-members: - :show-inheritance: + pe.histogram.histogram + pe.histogram.nearest_neighbors diff --git a/doc/source/api/pe.llm.azure_openai.rst b/doc/source/api/pe.llm.azure_openai.rst new file mode 100644 index 0000000..c97bffe --- /dev/null +++ b/doc/source/api/pe.llm.azure_openai.rst @@ -0,0 +1,7 @@ +pe.llm.azure\_openai module +=========================== + +.. automodule:: pe.llm.azure_openai + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.llm.huggingface.huggingface.rst b/doc/source/api/pe.llm.huggingface.huggingface.rst new file mode 100644 index 0000000..a0acca9 --- /dev/null +++ b/doc/source/api/pe.llm.huggingface.huggingface.rst @@ -0,0 +1,7 @@ +pe.llm.huggingface.huggingface module +===================================== + +.. 
automodule:: pe.llm.huggingface.huggingface + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.llm.huggingface.register_fastchat.gpt2.rst b/doc/source/api/pe.llm.huggingface.register_fastchat.gpt2.rst new file mode 100644 index 0000000..df6b68a --- /dev/null +++ b/doc/source/api/pe.llm.huggingface.register_fastchat.gpt2.rst @@ -0,0 +1,7 @@ +pe.llm.huggingface.register\_fastchat.gpt2 module +================================================= + +.. automodule:: pe.llm.huggingface.register_fastchat.gpt2 + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.llm.huggingface.register_fastchat.rst b/doc/source/api/pe.llm.huggingface.register_fastchat.rst new file mode 100644 index 0000000..0e3f33c --- /dev/null +++ b/doc/source/api/pe.llm.huggingface.register_fastchat.rst @@ -0,0 +1,15 @@ +pe.llm.huggingface.register\_fastchat package +============================================= + +.. automodule:: pe.llm.huggingface.register_fastchat + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 7 + + pe.llm.huggingface.register_fastchat.gpt2 diff --git a/doc/source/api/pe.llm.huggingface.rst b/doc/source/api/pe.llm.huggingface.rst new file mode 100644 index 0000000..1032504 --- /dev/null +++ b/doc/source/api/pe.llm.huggingface.rst @@ -0,0 +1,23 @@ +pe.llm.huggingface package +========================== + +.. automodule:: pe.llm.huggingface + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 7 + + pe.llm.huggingface.register_fastchat + +Submodules +---------- + +.. toctree:: + :maxdepth: 7 + + pe.llm.huggingface.huggingface diff --git a/doc/source/api/pe.llm.llm.rst b/doc/source/api/pe.llm.llm.rst new file mode 100644 index 0000000..2e2ea48 --- /dev/null +++ b/doc/source/api/pe.llm.llm.rst @@ -0,0 +1,7 @@ +pe.llm.llm module +================= + +.. automodule:: pe.llm.llm + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.llm.openai.rst b/doc/source/api/pe.llm.openai.rst new file mode 100644 index 0000000..e771bdd --- /dev/null +++ b/doc/source/api/pe.llm.openai.rst @@ -0,0 +1,7 @@ +pe.llm.openai module +==================== + +.. automodule:: pe.llm.openai + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.llm.request.rst b/doc/source/api/pe.llm.request.rst new file mode 100644 index 0000000..10f7d0b --- /dev/null +++ b/doc/source/api/pe.llm.request.rst @@ -0,0 +1,7 @@ +pe.llm.request module +===================== + +.. automodule:: pe.llm.request + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.llm.rst b/doc/source/api/pe.llm.rst new file mode 100644 index 0000000..f3ef261 --- /dev/null +++ b/doc/source/api/pe.llm.rst @@ -0,0 +1,26 @@ +pe.llm package +============== + +.. automodule:: pe.llm + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 7 + + pe.llm.huggingface + +Submodules +---------- + +.. toctree:: + :maxdepth: 7 + + pe.llm.azure_openai + pe.llm.llm + pe.llm.openai + pe.llm.request diff --git a/doc/source/api/pe.logger.csv_print.rst b/doc/source/api/pe.logger.csv_print.rst new file mode 100644 index 0000000..4165cce --- /dev/null +++ b/doc/source/api/pe.logger.csv_print.rst @@ -0,0 +1,7 @@ +pe.logger.csv\_print module +=========================== + +.. 
automodule:: pe.logger.csv_print + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.logger.image_file.rst b/doc/source/api/pe.logger.image_file.rst new file mode 100644 index 0000000..4c1cd89 --- /dev/null +++ b/doc/source/api/pe.logger.image_file.rst @@ -0,0 +1,7 @@ +pe.logger.image\_file module +============================ + +.. automodule:: pe.logger.image_file + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.logger.log_print.rst b/doc/source/api/pe.logger.log_print.rst new file mode 100644 index 0000000..50b5724 --- /dev/null +++ b/doc/source/api/pe.logger.log_print.rst @@ -0,0 +1,7 @@ +pe.logger.log\_print module +=========================== + +.. automodule:: pe.logger.log_print + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.logger.logger.rst b/doc/source/api/pe.logger.logger.rst new file mode 100644 index 0000000..8a0eca9 --- /dev/null +++ b/doc/source/api/pe.logger.logger.rst @@ -0,0 +1,7 @@ +pe.logger.logger module +======================= + +.. automodule:: pe.logger.logger + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.logger.matplotlib_pdf.rst b/doc/source/api/pe.logger.matplotlib_pdf.rst new file mode 100644 index 0000000..cc8ded0 --- /dev/null +++ b/doc/source/api/pe.logger.matplotlib_pdf.rst @@ -0,0 +1,7 @@ +pe.logger.matplotlib\_pdf module +================================ + +.. automodule:: pe.logger.matplotlib_pdf + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.logger.rst b/doc/source/api/pe.logger.rst index 5c3b302..4e686a7 100644 --- a/doc/source/api/pe.logger.rst +++ b/doc/source/api/pe.logger.rst @@ -9,42 +9,11 @@ pe.logger package Submodules ---------- -pe.logger.csv\_print module ---------------------------- - -.. automodule:: pe.logger.csv_print - :members: - :undoc-members: - :show-inheritance: - -pe.logger.image\_file module ----------------------------- - -.. automodule:: pe.logger.image_file - :members: - :undoc-members: - :show-inheritance: - -pe.logger.log\_print module ---------------------------- - -.. automodule:: pe.logger.log_print - :members: - :undoc-members: - :show-inheritance: - -pe.logger.logger module ------------------------ - -.. automodule:: pe.logger.logger - :members: - :undoc-members: - :show-inheritance: - -pe.logger.matplotlib\_pdf module --------------------------------- - -.. automodule:: pe.logger.matplotlib_pdf - :members: - :undoc-members: - :show-inheritance: +.. toctree:: + :maxdepth: 7 + + pe.logger.csv_print + pe.logger.image_file + pe.logger.log_print + pe.logger.logger + pe.logger.matplotlib_pdf diff --git a/doc/source/api/pe.population.pe_population.rst b/doc/source/api/pe.population.pe_population.rst new file mode 100644 index 0000000..9fc5222 --- /dev/null +++ b/doc/source/api/pe.population.pe_population.rst @@ -0,0 +1,7 @@ +pe.population.pe\_population module +=================================== + +.. automodule:: pe.population.pe_population + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.population.population.rst b/doc/source/api/pe.population.population.rst new file mode 100644 index 0000000..96f27f0 --- /dev/null +++ b/doc/source/api/pe.population.population.rst @@ -0,0 +1,7 @@ +pe.population.population module +=============================== + +.. 
automodule:: pe.population.population + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.population.rst b/doc/source/api/pe.population.rst index c7f26a6..bc4853b 100644 --- a/doc/source/api/pe.population.rst +++ b/doc/source/api/pe.population.rst @@ -9,18 +9,8 @@ pe.population package Submodules ---------- -pe.population.pe\_population module ------------------------------------ +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.population.pe_population - :members: - :undoc-members: - :show-inheritance: - -pe.population.population module -------------------------------- - -.. automodule:: pe.population.population - :members: - :undoc-members: - :show-inheritance: + pe.population.pe_population + pe.population.population diff --git a/doc/source/api/pe.rst b/doc/source/api/pe.rst index 5d6227f..67b9bb9 100644 --- a/doc/source/api/pe.rst +++ b/doc/source/api/pe.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 3 + :maxdepth: 7 pe.api pe.callback @@ -19,6 +19,7 @@ Subpackages pe.dp pe.embedding pe.histogram + pe.llm pe.logger pe.logging pe.metric_item diff --git a/doc/source/api/pe.runner.pe.rst b/doc/source/api/pe.runner.pe.rst new file mode 100644 index 0000000..28f277d --- /dev/null +++ b/doc/source/api/pe.runner.pe.rst @@ -0,0 +1,7 @@ +pe.runner.pe module +=================== + +.. automodule:: pe.runner.pe + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.runner.rst b/doc/source/api/pe.runner.rst index f7b7b13..3f2144b 100644 --- a/doc/source/api/pe.runner.rst +++ b/doc/source/api/pe.runner.rst @@ -9,10 +9,7 @@ pe.runner package Submodules ---------- -pe.runner.pe module -------------------- +.. toctree:: + :maxdepth: 7 -.. automodule:: pe.runner.pe - :members: - :undoc-members: - :show-inheritance: + pe.runner.pe diff --git a/doc/source/api/pe.util.download.rst b/doc/source/api/pe.util.download.rst new file mode 100644 index 0000000..b57814d --- /dev/null +++ b/doc/source/api/pe.util.download.rst @@ -0,0 +1,7 @@ +pe.util.download module +======================= + +.. automodule:: pe.util.download + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/source/api/pe.util.rst b/doc/source/api/pe.util.rst index de60c5b..e53e6f4 100644 --- a/doc/source/api/pe.util.rst +++ b/doc/source/api/pe.util.rst @@ -9,10 +9,7 @@ pe.util package Submodules ---------- -pe.util.download module ------------------------ +.. toctree:: + :maxdepth: 7 -.. 
automodule:: pe.util.download - :members: - :undoc-members: - :show-inheritance: + pe.util.download diff --git a/doc/source/conf.py b/doc/source/conf.py index c64d68e..4263449 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -14,8 +14,6 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = [] - templates_path = ["_templates"] exclude_patterns = [] @@ -24,8 +22,10 @@ html_theme = "sphinx_rtd_theme" # 'alabaster' html_static_path = ["_static"] +html_favicon = "icon/favicon.ico" +html_logo = "icon/icon.png" -html_theme_options = {"navigation_depth": 6} +html_theme_options = {"navigation_depth": -1, "collapse_navigation": False} extensions = [ "sphinx.ext.autodoc", @@ -38,6 +38,7 @@ "sphinx.ext.viewcode", "sphinx.ext.githubpages", "sphinx.ext.napoleon", + "sphinx_toolbox.more_autodoc.autonamedtuple", ] # Napoleon settings @@ -54,3 +55,17 @@ napoleon_use_rtype = True numfig = True +nitpicky = True +nitpick_ignore = [ + ("py:class", "optional"), + ("py:class", "abc.ABC"), + ("py:class", "np.ndarray"), + ("py:class", "fastchat.conversation.Conversation"), + ("py:class", "fastchat.model.model_adapter.BaseModelAdapter"), + ("py:class", "torch.utils.data.dataset.Dataset"), + ("py:class", "torch.nn.modules.module.Module"), + ("py:class", "Module"), + ("py:class", "pandas.DataFrame"), + ("py:class", "improved_diffusion.respace.SpacedDiffusion"), + ("py:class", "improved_diffusion.unet.UNetModel"), +] diff --git a/doc/source/getting_started/details/api.rst b/doc/source/getting_started/details/api.rst index af3ada6..f6ee978 100644 --- a/doc/source/getting_started/details/api.rst +++ b/doc/source/getting_started/details/api.rst @@ -3,10 +3,10 @@ APIs API reference: :doc:`/api/pe.api` -:py:class:`pe.api.api.API` is responsible for implementing the foundation model APIs. It has the following key methods: +:py:class:`pe.api.API` is responsible for implementing the foundation model APIs. It has the following key methods: -* :py:meth:`pe.api.api.API.random_api`: Randomly generates the synthetic samples for the initial samples of the **Private Evolution** algorithm. -* :py:meth:`pe.api.api.API.variation_api`: Generates the variations of the given synthetic samples for the initial or the next **Private Evolution** iteration. +* :py:meth:`pe.api.API.random_api`: Randomly generates the synthetic samples for the initial samples of the **Private Evolution** algorithm. +* :py:meth:`pe.api.API.variation_api`: Generates the variations of the given synthetic samples for the initial or the next **Private Evolution** iteration. Available APIs -------------- @@ -15,17 +15,21 @@ Currently, the following APIs are implemented: * Images - * :py:class:`pe.api.image.stable_diffusion_api.StableDiffusion`: The APIs of `Stable Diffusion`_. - * :py:class:`pe.api.image.improved_diffusion_api.ImprovedDiffusion`: The APIs of the `improved diffusion model`_. + * :py:class:`pe.api.StableDiffusion`: The APIs of `Stable Diffusion`_. + * :py:class:`pe.api.ImprovedDiffusion`: The APIs of the `improved diffusion model`_. * Text - * Coming soon! + * :py:class:`pe.api.LLMAugPE`: The APIs for text generation using LLMs. When constructing the instance of this API, an LLM instance is required. The LLM instances follow the interface of :py:class:`pe.llm.LLM`. Currently, the following LLMs are implemented: + + * :py:class:`pe.llm.OpenAILLM`: The LLMs from OpenAI APIs. 
+  * :py:class:`pe.llm.AzureOpenAILLM`: The LLMs from Azure OpenAI APIs. +  * :py:class:`pe.llm.HuggingfaceLLM`: The open-source LLMs from Huggingface. Adding Your Own APIs -------------------- -To add your own APIs, you need to create a class that inherits from :py:class:`pe.api.api.API` and implements the :py:meth:`pe.api.api.API.random_api` and :py:meth:`pe.api.api.API.variation_api` methods. +To add your own APIs, you need to create a class that inherits from :py:class:`pe.api.API` and implements the :py:meth:`pe.api.API.random_api` and :py:meth:`pe.api.API.variation_api` methods. .. _improved diffusion model: https://github.com/openai/improved-diffusion diff --git a/doc/source/getting_started/details/callback_and_logger.rst b/doc/source/getting_started/details/callback_and_logger.rst index cd68f30..eec56e5 100644 --- a/doc/source/getting_started/details/callback_and_logger.rst +++ b/doc/source/getting_started/details/callback_and_logger.rst @@ -3,7 +3,7 @@ Callbacks and Loggers API reference: :doc:`/api/pe.callback` and :doc:`/api/pe.logger`. -:py:class:`pe.callback.callback.Callback` can be configured to be called after each **Private Evolution** iteration with the synthetic data as the input. It is useful for computing metrics, saving the synthetic samples, monitoring the progress, etc. Each :py:class:`pe.callback.callback.Callback` can return a list of results (float numbers, images, matplotlib plots, etc.) in the form of :py:class:`pe.metric_item.MetricItem` (see :py:mod:`pe.metric_item`). All :py:class:`pe.metric_item.MetricItem` from all :py:class:`pe.callback.callback.Callback` will be passed through each of the :py:class:`pe.logger.logger.Logger` modules, which will then log the results in the desired way. +:py:class:`pe.callback.Callback` can be configured to be called after each **Private Evolution** iteration with the synthetic data as the input. It is useful for computing metrics, saving the synthetic samples, monitoring the progress, etc. Each :py:class:`pe.callback.Callback` can return a list of results (float numbers, images, matplotlib plots, etc.) in the form of :py:class:`pe.metric_item.MetricItem` (see :py:mod:`pe.metric_item`). All :py:class:`pe.metric_item.MetricItem` from all :py:class:`pe.callback.Callback` will be passed through each of the :py:class:`pe.logger.Logger` modules, which will then log the results in the desired way. Available Callbacks ------------------- @@ -12,17 +12,17 @@ Currently, the following callbacks are implemented: * For any data modality - * :py:class:`pe.callback.common.compute_fid.ComputeFID`: Computes the FID between the synthetic samples and the private samples. - * :py:class:`pe.callback.common.save_checkpoints.SaveCheckpoints`: Saves the checkpoint of current synthetic samples to files. + * :py:class:`pe.callback.ComputeFID`: Computes the FID between the synthetic samples and the private samples. + * :py:class:`pe.callback.SaveCheckpoints`: Saves the checkpoint of current synthetic samples to files. * Images - * :py:class:`pe.callback.image.sample_images.SampleImages`: Samples some images from each class. - * :py:class:`pe.callback.image.save_all_images.SaveAllImages`: Saves all synthetic images to files. + * :py:class:`pe.callback.SampleImages`: Samples some images from each class. + * :py:class:`pe.callback.SaveAllImages`: Saves all synthetic images to files. * Text - * Coming soon! + * :py:class:`pe.callback.SaveTextToCSV`: Saves all text samples to a CSV file. 
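+
+As an illustration, a custom callback is typically a small callable class that receives the current synthetic :py:class:`pe.data.Data` after each iteration and returns a list of :py:class:`pe.metric_item.MetricItem` objects. The sketch below is illustrative only; the exact method name, the metric item class, and its constructor arguments are assumptions that should be checked against :py:class:`pe.callback.Callback` and :py:mod:`pe.metric_item`.
+
+.. code-block:: python
+
+    from pe.callback import Callback
+    from pe.metric_item import FloatMetricItem  # assumed name of the float metric item class
+
+    class CountSyntheticSamples(Callback):
+        """Report the number of synthetic samples after each iteration (illustrative sketch)."""
+
+        def __call__(self, syn_data):
+            # Assumed interface: the runner calls the callback with the current synthetic
+            # data; the returned metric items are passed to the configured loggers.
+            return [FloatMetricItem(name="num_synthetic_samples", value=len(syn_data.data_frame))]
+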
Available Loggers @@ -30,7 +30,7 @@ Available Loggers Currently, the following loggers are implemented: -* :py:class:`pe.logger.csv_print.CSVPrint`: Saves the float results to a CSV file. -* :py:class:`pe.logger.log_print.LogPrint`: Prints the float results to the console and/or files using the logging module. -* :py:class:`pe.logger.image_file.ImageFile`: Saves the images to files. -* :py:class:`pe.logger.matplotlib_pdf.MatplotlibPDF`: Saves the matplotlib plots to PDF files. +* :py:class:`pe.logger.CSVPrint`: Saves the float results to a CSV file. +* :py:class:`pe.logger.LogPrint`: Prints the float results to the console and/or files using the logging module. +* :py:class:`pe.logger.ImageFile`: Saves the images to files. +* :py:class:`pe.logger.MatplotlibPDF`: Saves the matplotlib plots to PDF files. diff --git a/doc/source/getting_started/details/data.rst b/doc/source/getting_started/details/data.rst index 5d619a8..02bfefc 100644 --- a/doc/source/getting_started/details/data.rst +++ b/doc/source/getting_started/details/data.rst @@ -3,8 +3,8 @@ Data API reference: :doc:`/api/pe.data`. -:py:class:`pe.data.data.Data` is the base class for holding the synthetic samples or the private samples, along with their metadata. Different components are mostly communicated through objects of this class. -:py:class:`pe.data.data.Data` has two key attributes: +:py:class:`pe.data.Data` is the base class for holding the synthetic samples or the private samples, along with their metadata. Different components are mostly communicated through objects of this class. +:py:class:`pe.data.Data` has two key attributes: * ``data_frame``: A pandas_ DataFrame that holds the samples. Each row in the DataFrame is a sample, and each column is part of the sample (e.g., the image, the text, the label) and other information of the sample (e.g., its embedding produced by :doc:`embedding`). * ``metadata``: A OmegaConf_ that holds the metadata of the samples, such as the **Private Evolution** iteration number when the samples are generated, and the label names of the classes. @@ -12,22 +12,25 @@ API reference: :doc:`/api/pe.data`. Available Datasets ------------------ -For convenience, some well-known datasets are already packaged as `pe.data.data.Data` classes: +For convenience, some well-known datasets are already packaged as :py:class:`pe.data.Data` classes: * Image datasets - * :py:class:`pe.data.image.cifar10.Cifar10`: The `CIFAR10 dataset`_. - * :py:class:`pe.data.image.camelyon17.Camelyon17`: The `Camelyon17 dataset`_. - * :py:class:`pe.data.image.cat.Cat`: The `Cat dataset`_. - * In addition, you can easily load a custom image dataset from a (nested) directory with the image files using :py:meth:`pe.data.image.image.load_image_folder`. + * :py:class:`pe.data.Cifar10`: The `CIFAR10 dataset`_. + * :py:class:`pe.data.Camelyon17`: The `Camelyon17 dataset`_. + * :py:class:`pe.data.Cat`: The `Cat dataset`_. + * In addition, you can easily load a custom image dataset from a (nested) directory with the image files using :py:meth:`pe.data.load_image_folder`. * Text datasets - * Coming soon! + * :py:class:`pe.data.Yelp`: The `Yelp dataset`_. + * :py:class:`pe.data.OpenReview`: The `OpenReview dataset`_. + * :py:class:`pe.data.PubMed`: The `PubMed dataset`_. + * In addition, you can easily load a custom text dataset from a CSV file using :py:class:`pe.data.TextCSV`. 
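+
+For example, loading one of these datasets as the private data could look like the following sketch (the constructor arguments shown here, such as ``csv_path``, are assumptions and should be checked against the API reference of each class):
+
+.. code-block:: python
+
+    from pe.data import Cifar10, TextCSV
+
+    # Use a packaged image dataset as the private data.
+    private_data = Cifar10()
+
+    # Or load a custom text dataset from a CSV file
+    # (the path and the argument name are illustrative).
+    private_data = TextCSV(csv_path="my_private_reviews.csv")
+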
Using Your Own Datasets ----------------------- -To apply **Private Evolution** to your own private dataset, you need to create a :py:class:`pe.data.data.Data` object that holds your dataset, with two parameters, ``data_frame`` and ``metadata``, passed to the constructor: +To apply **Private Evolution** to your own private dataset, you need to create a :py:class:`pe.data.Data` object that holds your dataset, with two parameters, ``data_frame`` and ``metadata``, passed to the constructor: * ``data_frame``: A pandas_ DataFrame that holds the samples. Each row in the DataFrame is a sample. The following columns must be included: @@ -47,3 +50,6 @@ To apply **Private Evolution** to your own private dataset, you need to create a .. _Cat dataset: https://www.kaggle.com/datasets/fjxmlzn/cat-cookie-doudou .. _CIFAR10 dataset: https://www.cs.toronto.edu/~kriz/cifar.html .. _Camelyon17 dataset: https://camelyon17.grand-challenge.org/ +.. _Yelp dataset: https://github.com/AI-secure/aug-pe/tree/main/data +.. _OpenReview dataset: https://github.com/AI-secure/aug-pe/tree/main/data +.. _PubMed dataset: https://github.com/AI-secure/aug-pe/tree/main/data diff --git a/doc/source/getting_started/details/dp.rst b/doc/source/getting_started/details/dp.rst index 2a3e57c..bb65e3f 100644 --- a/doc/source/getting_started/details/dp.rst +++ b/doc/source/getting_started/details/dp.rst @@ -3,14 +3,14 @@ DP API reference: :doc:`/api/pe.dp`. -:py:class:`pe.dp.dp.DP` is responsible for implementing the differential privacy mechanism. It has the following key methods: +:py:class:`pe.dp.DP` is responsible for implementing the differential privacy mechanism. It has the following key methods: -* :py:meth:`pe.dp.dp.DP.set_epsilon_and_delta`: Set the privacy budget for the differential privacy mechanism. -* :py:meth:`pe.dp.dp.DP.add_noise`: Add noise to the histogram values to achieve differential privacy. +* :py:meth:`pe.dp.DP.set_epsilon_and_delta`: Set the privacy budget for the differential privacy mechanism. +* :py:meth:`pe.dp.DP.add_noise`: Add noise to the histogram values to achieve differential privacy. Available Differential Privacy Mechanisms ----------------------------------------- Currently, the following differential privacy mechanisms are implemented: -* :py:class:`pe.dp.gaussian.Gaussian`: The Gaussian mechanism, which adds Gaussian noise to the histogram values. \ No newline at end of file +* :py:class:`pe.dp.Gaussian`: The Gaussian mechanism, which adds Gaussian noise to the histogram values. \ No newline at end of file diff --git a/doc/source/getting_started/details/embedding.rst b/doc/source/getting_started/details/embedding.rst index acf93ec..e104491 100644 --- a/doc/source/getting_started/details/embedding.rst +++ b/doc/source/getting_started/details/embedding.rst @@ -5,8 +5,8 @@ API reference: :doc:`/api/pe.embedding`. :py:class:`pe.embedding.embedding.Embedding` is responsible for computing the embeddings of the (synthetic or private) samples. It has the following key methods/attributes: -* :py:meth:`pe.embedding.embedding.Embedding.compute_embedding`: Computes the embeddings of the (synthetic or private) samples. -* :py:attr:`pe.embedding.embedding.Embedding.column_name`: The column name to be used when saving the embeddings in the ``data_frame`` of `pe.data.data.Data`. +* :py:meth:`pe.embedding.Embedding.compute_embedding`: Computes the embeddings of the (synthetic or private) samples. 
+* :py:attr:`pe.embedding.Embedding.column_name`: The column name to be used when saving the embeddings in the ``data_frame`` of `pe.data.Data`. Available Embeddings -------------------- @@ -15,8 +15,11 @@ Currently, the following embeddings are implemented: * Images - * :py:class:`pe.embedding.image.inception.Inception`: The embeddings computed using the Inception model. + * :py:class:`pe.embedding.Inception`: The embeddings computed using the Inception model. * Text - * Coming soon! + * :py:class:`pe.embedding.SentenceTransformer`: The embeddings computed using the `Sentence Transformers`_ library. + + +.. _Sentence Transformers: https://sbert.net/ diff --git a/doc/source/getting_started/details/histogram.rst b/doc/source/getting_started/details/histogram.rst index b7ee4ac..c4007ce 100644 --- a/doc/source/getting_started/details/histogram.rst +++ b/doc/source/getting_started/details/histogram.rst @@ -3,13 +3,13 @@ Histograms API reference: :doc:`/api/pe.histogram`. -:py:class:`pe.histogram.histogram.Histogram` is responsible for generating the histograms over the synthetic samples. It has the following key methods: +:py:class:`pe.histogram.Histogram` is responsible for generating the histograms over the synthetic samples. It has the following key methods: -* :py:meth:`pe.histogram.histogram.Histogram.compute_histogram`: Generates the histograms over the synthetic samples using private samples. +* :py:meth:`pe.histogram.Histogram.compute_histogram`: Generates the histograms over the synthetic samples using private samples. Available Histograms -------------------- Currently, the following histograms are implemented: -* :py:class:`pe.histogram.nearest_neighbors.NearestNeighbors`: This histogram algorithm projects the synthetic samples and the private samples into an embedding space and computes the nearest neighbor(s) of each private sample in the synthetic samples. The histogram value for each synthetic sample is the number of times it is the nearest neighbor(s) of a private sample. \ No newline at end of file +* :py:class:`pe.histogram.NearestNeighbors`: This histogram algorithm projects the synthetic samples and the private samples into an embedding space and computes the nearest neighbor(s) of each private sample in the synthetic samples. The histogram value for each synthetic sample is the number of times it is the nearest neighbor(s) of a private sample. \ No newline at end of file diff --git a/doc/source/getting_started/details/population.rst b/doc/source/getting_started/details/population.rst index 0396076..0e87cea 100644 --- a/doc/source/getting_started/details/population.rst +++ b/doc/source/getting_started/details/population.rst @@ -3,13 +3,12 @@ Population API reference: :doc:`/api/pe.population`. -:py:class:`pe.population.population.Population` is responsible for generating the initial synthetic samples and the new synthetic samples for each **Private Evolution** iteration. It has the following key methods: +:py:class:`pe.population.Population` is responsible for generating the initial synthetic samples and the new synthetic samples for each **Private Evolution** iteration. It has the following key methods: -* :py:meth:`pe.population.population.Population.initial`: Generates the initial synthetic samples. -* :py:meth:`pe.population.population.Population.next`: Generates the synthetic samples for the next **Private Evolution** iteration. +* :py:meth:`pe.population.Population.initial`: Generates the initial synthetic samples. 
+* :py:meth:`pe.population.Population.next`: Generates the synthetic samples for the next **Private Evolution** iteration. Available Populations --------------------- -:py:class:`pe.population.pe_population.PEPopulation` is currently the only implementation of :py:class:`pe.population.population.Population`. It supports the key population algorthms from existing **Private Evolution** papers (https://github.com/fjxmlzn/private-evolution-papers). - +:py:class:`pe.population.PEPopulation` is currently the only implementation of :py:class:`pe.population.Population`. It supports the key population algorithms from existing **Private Evolution** papers (https://github.com/fjxmlzn/private-evolution-papers). diff --git a/doc/source/getting_started/details/runner.rst b/doc/source/getting_started/details/runner.rst index 24054b0..6fab3d3 100644 --- a/doc/source/getting_started/details/runner.rst +++ b/doc/source/getting_started/details/runner.rst @@ -3,6 +3,6 @@ Runner API reference: :doc:`/api/pe.runner`. -:py:class:`pe.runner.pe.PE` manages the main **Private Evolution** algorithm by calling the other components discussed before. It has the following key methods: +:py:class:`pe.runner.PE` manages the main **Private Evolution** algorithm by calling the other components discussed before. It has the following key methods: -* :py:meth:`pe.runner.pe.PE.run`: Runs the **Private Evolution** algorithm. +* :py:meth:`pe.runner.PE.run`: Runs the **Private Evolution** algorithm. diff --git a/doc/source/getting_started/examples.rst b/doc/source/getting_started/examples.rst index a5d17cd..04d5f18 100644 --- a/doc/source/getting_started/examples.rst +++ b/doc/source/getting_started/examples.rst @@ -6,25 +6,54 @@ Here are some examples of how to use the **Private Evolution** library. Images ------ -These examples follow the experimental settings in the paper `Differentially Private Synthetic Data via Foundation Model APIs 1: Images (ICLR 2024) `__. +These examples follow the experimental settings in the paper `Differentially Private Synthetic Data via Foundation Model APIs 1: Images (ICLR 2024) `__. -* **CIFAR10**: `This example `__ shows how to generate differentially private synthetic images for the `CIFAR10 dataset`_ using the APIs from a pre-trained `ImageNet diffusion model`_. +* **CIFAR10 dataset**: `This example `__ shows how to generate differentially private synthetic images for the `CIFAR10 dataset`_ using the APIs from a pre-trained `ImageNet diffusion model`_. -* **Camelyon17**: `This example `__ shows how to generate differentially private synthetic images for the `Camelyon17 dataset`_ using the APIs from a pre-trained `ImageNet diffusion model`_. +* **Camelyon17 dataset**: `This example `__ shows how to generate differentially private synthetic images for the `Camelyon17 dataset`_ using the APIs from a pre-trained `ImageNet diffusion model`_. -* **Cat**: `This example `__ shows how to generate differentially private synthetic images of the `Cat dataset`_ using the APIs from `Stable Diffusion`_. +* **Cat dataset**: `This example `__ shows how to generate differentially private synthetic images of the `Cat dataset`_ using the APIs from `Stable Diffusion`_. Text ---- -Coming soon! +These examples follow the experimental settings in the paper `Differentially Private Synthetic Data via Foundation Model APIs 2: Text (ICML 2024 Spotlight) `__. 
+ +* **Yelp dataset**: These examples show how to generate differentially private synthetic text for the `Yelp dataset`_ using LLM APIs from: + + * **OpenAI APIs**: `See example <Yelp OpenAI example_>`__ + * **Huggingface models**: `See example <Yelp Huggingface example_>`__ + +* **OpenReview dataset**: These examples show how to generate differentially private synthetic text for the `OpenReview dataset`_ using LLM APIs from: + + * **OpenAI APIs**: `See example <Openreview OpenAI example_>`__ + * **Huggingface models**: `See example <Openreview Huggingface example_>`__ + +* **PubMed dataset**: These examples show how to generate differentially private synthetic text for the `PubMed dataset`_ using LLM APIs from: + + * **OpenAI APIs**: `See example <PubMed OpenAI example_>`__ + * **Huggingface models**: `See example <PubMed Huggingface example_>`__ + .. _ImageNet diffusion model: https://github.com/openai/improved-diffusion .. _Stable Diffusion: https://huggingface.co/CompVis/stable-diffusion-v1-4 + .. _Cat dataset: https://www.kaggle.com/datasets/fjxmlzn/cat-cookie-doudou .. _CIFAR10 dataset: https://www.cs.toronto.edu/~kriz/cifar.html .. _Camelyon17 dataset: https://camelyon17.grand-challenge.org/ +.. _Yelp dataset: https://github.com/AI-secure/aug-pe/tree/main/data +.. _OpenReview dataset: https://github.com/AI-secure/aug-pe/tree/main/data +.. _PubMed dataset: https://github.com/AI-secure/aug-pe/tree/main/data + .. _CIFAR10 example: https://github.com/microsoft/DPSDA/blob/main/example/image/cifar10.py .. _Camelyon17 example: https://github.com/microsoft/DPSDA/blob/main/example/image/camelyon17.py .. _Cat example: https://github.com/microsoft/DPSDA/blob/main/example/image/cat.py -.. _paper: https://arxiv.org/abs/2305.15560 \ No newline at end of file +.. _Yelp OpenAI example: https://github.com/microsoft/DPSDA/blob/main/example/text/yelp_openai/main.py +.. _Yelp Huggingface example: https://github.com/microsoft/DPSDA/blob/main/example/text/yelp_huggingface/main.py +.. _Openreview OpenAI example: https://github.com/microsoft/DPSDA/blob/main/example/text/openreview_openai/main.py +.. _Openreview Huggingface example: https://github.com/microsoft/DPSDA/blob/main/example/text/openreview_huggingface/main.py +.. _PubMed OpenAI example: https://github.com/microsoft/DPSDA/blob/main/example/text/pubmed_openai/main.py +.. _PubMed Huggingface example: https://github.com/microsoft/DPSDA/blob/main/example/text/pubmed_huggingface/main.py + +.. _pe1_paper: https://arxiv.org/abs/2305.15560 +.. _pe2_paper: https://arxiv.org/abs/2403.01749 \ No newline at end of file diff --git a/doc/source/getting_started/getting_started.rst b/doc/source/getting_started/getting_started.rst index ecd0a76..9f2e916 100644 --- a/doc/source/getting_started/getting_started.rst +++ b/doc/source/getting_started/getting_started.rst @@ -9,5 +9,5 @@ Getting Started intro installation examples + using_your_own_data_apis details/details - diff --git a/doc/source/getting_started/installation.rst b/doc/source/getting_started/installation.rst index 3bf0ac8..f15c66f 100644 --- a/doc/source/getting_started/installation.rst +++ b/doc/source/getting_started/installation.rst @@ -4,18 +4,33 @@ Installation PIP --- +The Main Package +^^^^^^^^^^^^^^^^ + To install the core package of **Private Evolution**, please use the following command: ..
code-block:: bash pip install "private-evolution @ git+https://github.com/microsoft/DPSDA.git" -If you are using **Private Evolution** to generate images, use the following command instead to install the package with the necessary dependencies for image generation: +Image Generation +^^^^^^^^^^^^^^^^ + +If you are using **Private Evolution** to generate **images**, use the following command instead to install the package with the necessary dependencies: .. code-block:: bash pip install "private-evolution[image] @ git+https://github.com/microsoft/DPSDA.git" +Text Generation +^^^^^^^^^^^^^^^ + +If you are using **Private Evolution** to generate **text**, use the following command instead to install the package with the necessary dependencies: + +.. code-block:: bash + + pip install "private-evolution[text] @ git+https://github.com/microsoft/DPSDA.git" + Faiss ----- diff --git a/doc/source/getting_started/intro.rst b/doc/source/getting_started/intro.rst index 48b42fd..035abca 100644 --- a/doc/source/getting_started/intro.rst +++ b/doc/source/getting_started/intro.rst @@ -37,4 +37,4 @@ If you use **Private Evolution** in your research or work, please cite the follo .. literalinclude:: pe2.bib :language: bibtex -Please see https://github.com/fjxmlzn/private-evolution-papers for the full list of **Private Evolution** papers done by the community. +Please see https://github.com/fjxmlzn/private-evolution-papers for the full list of **Private Evolution** papers and code repositories contributed by the community. diff --git a/doc/source/getting_started/using_your_own_data_apis.rst b/doc/source/getting_started/using_your_own_data_apis.rst new file mode 100644 index 0000000..340218d --- /dev/null +++ b/doc/source/getting_started/using_your_own_data_apis.rst @@ -0,0 +1,21 @@ +Using Your Own Data/APIs +======================== + + +To apply **Private Evolution** to your own data, domain, or application, you most likely only need to provide your own data (an object of :py:class:`pe.data.Data`) and APIs (an object of :py:class:`pe.api.API`). +The **Private Evolution** library comes preloaded with popular datasets and APIs, and you can also easily bring your own. Here is how to do it. + +Data +---- + +* **Preloaded datasets**: Some well-known datasets are already packaged as :py:class:`pe.data.Data` classes. Please refer to :doc:`this document
` for more details. +* **New image datasets**: You can easily load a custom image dataset from a (nested) directory with the image files using :py:meth:`pe.data.load_image_folder`. +* **New text datasets**: You can easily load a custom text dataset from a CSV file using :py:class:`pe.data.TextCSV`. +* **Beyond the above**: You can create a :py:class:`pe.data.Data` object that holds your dataset, with two parameters, ``data_frame`` and ``metadata``, passed to the constructor. The ``data_frame`` is a pandas DataFrame that holds the samples, and the ``metadata`` is a dictionary that holds the metadata of the samples. Please refer to :doc:`this document
` for more details. + +APIs +---- + +* **Preloaded APIs**: Some well-known APIs used in prior **Private Evolution** papers are already packaged as :py:class:`pe.api.API` classes. Please refer to :doc:`this document
` for more details. +* **Beyond the above**: You can create a class that inherits from :py:class:`pe.api.API` and implements the :py:meth:`pe.api.API.random_api` and :py:meth:`pe.api.API.variation_api` methods. Please refer to :doc:`this document
` for more details. + \ No newline at end of file diff --git a/doc/source/icon/favicon.ico b/doc/source/icon/favicon.ico new file mode 100644 index 0000000..c744629 Binary files /dev/null and b/doc/source/icon/favicon.ico differ diff --git a/doc/source/icon/icon.png b/doc/source/icon/icon.png new file mode 100644 index 0000000..ef33cc4 Binary files /dev/null and b/doc/source/icon/icon.png differ diff --git a/doc/source/index.rst b/doc/source/index.rst index 1c98d94..d9c0c45 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -10,7 +10,7 @@ The source code of this **Private Evolution** library is available at https://gi .. toctree:: - :maxdepth: 5 + :maxdepth: 7 :caption: Contents: getting_started/getting_started diff --git a/example/text/openreview_huggingface/main.py b/example/text/openreview_huggingface/main.py new file mode 100644 index 0000000..497adf8 --- /dev/null +++ b/example/text/openreview_huggingface/main.py @@ -0,0 +1,81 @@ +""" +This example follows the experimental settings of the GPT-2 OpenReview experiments in the ICML 2024 Spotlight paper, +"Differentially Private Synthetic Data via Foundation Model APIs 2: Text" (https://arxiv.org/abs/2403.01749). + +The ``model_name_or_path`` parameter can be set to other models on HuggingFace. Note that we use the FastChat +library (https://github.com/lm-sys/FastChat) to manage the conversation template. If the conversation template of your +desired model is not available in FastChat, please register the conversation template in the FastChat library. See the +following link for an example: +https://github.com/microsoft/DPSDA/blob/main/pe/llm/huggingface/register_fastchat/gpt2.py + +For detailed information about parameters and APIs, please consult the documentation of the Private Evolution library: +https://microsoft.github.io/DPSDA/. 
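# A minimal sketch of the "Using Your Own Data/APIs" guide added above
# (doc/source/getting_started/using_your_own_data_apis.rst): a custom dataset is a pandas
# DataFrame plus a metadata dict wrapped in pe.data.Data, and a custom API subclasses
# pe.api.API and implements random_api and variation_api. The column name, metadata
# contents, and method signatures below are placeholders for illustration, not the
# library's documented interface.
import pandas as pd

from pe.data import Data
from pe.api import API


def build_my_data():
    # Hypothetical single-column DataFrame; real datasets also need whatever metadata the
    # chosen embedding/histogram/API components expect.
    data_frame = pd.DataFrame({"text": ["a private sample", "another private sample"]})
    metadata = {}  # placeholder; see the data documentation for the expected keys
    return Data(data_frame=data_frame, metadata=metadata)


class MyAPI(API):
    """A skeleton custom API; see the API reference for the exact method signatures."""

    def random_api(self, *args, **kwargs):
        raise NotImplementedError  # generate the initial random synthetic samples

    def variation_api(self, *args, **kwargs):
        raise NotImplementedError  # generate variations of the given synthetic samples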
+""" + +from pe.data.text import OpenReview +from pe.logging import setup_logging +from pe.runner import PE +from pe.population import PEPopulation +from pe.api.text import LLMAugPE +from pe.llm import HuggingfaceLLM +from pe.embedding.text import SentenceTransformer +from pe.histogram import NearestNeighbors +from pe.callback import SaveCheckpoints +from pe.callback import ComputeFID +from pe.callback import SaveTextToCSV +from pe.logger import CSVPrint +from pe.logger import LogPrint + +import pandas as pd +import os +import numpy as np + +pd.options.mode.copy_on_write = True + + +if __name__ == "__main__": + exp_folder = "results/text/openreview_huggingface" + current_folder = os.path.dirname(os.path.abspath(__file__)) + + setup_logging(log_file=os.path.join(exp_folder, "log.txt")) + + data = OpenReview(root_dir="/tmp/data/openreview") + llm = HuggingfaceLLM(max_completion_tokens=448, model_name_or_path="gpt2", temperature=1.2) + api = LLMAugPE( + llm=llm, + random_api_prompt_file=os.path.join(current_folder, "random_api_prompt.json"), + variation_api_prompt_file=os.path.join(current_folder, "variation_api_prompt.json"), + ) + embedding = SentenceTransformer(model="stsb-roberta-base-v2") + histogram = NearestNeighbors( + embedding=embedding, + mode="L2", + lookahead_degree=0, + ) + population = PEPopulation( + api=api, initial_variation_api_fold=6, next_variation_api_fold=6, keep_selected=True, selection_mode="rank" + ) + + save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint")) + compute_fid = ComputeFID(priv_data=data, embedding=embedding) + save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text")) + + csv_print = CSVPrint(output_folder=exp_folder) + log_print = LogPrint() + + num_private_samples = len(data.data_frame) + delta = 1.0 / num_private_samples / np.log(num_private_samples) + + pe_runner = PE( + priv_data=data, + population=population, + histogram=histogram, + callbacks=[save_checkpoints, save_text_to_csv, compute_fid], + loggers=[csv_print, log_print], + ) + pe_runner.run( + num_samples_schedule=[2000] * 11, + delta=delta, + epsilon=1.0, + checkpoint_path=os.path.join(exp_folder, "checkpoint"), + ) diff --git a/example/text/openreview_huggingface/random_api_prompt.json b/example/text/openreview_huggingface/random_api_prompt.json new file mode 100644 index 0000000..5b10cce --- /dev/null +++ b/example/text/openreview_huggingface/random_api_prompt.json @@ -0,0 +1,67 @@ +{ + "message_template": [ + { + "role": "user", + "content": "Suppose that you are a {writer}. Write a paper review based on Area: {area}\tRecommendation: {recommendation}" + } + ], + "replacement_rules": [ + { + "constraints": {}, + "replacements": { + "writer": [ + "Senior Research Scientist who has Expert and detailed analysis", + "Ph.D. 
Student who has Inquisitive and learning-focused tone", + "Industry Practitioner who has Practical and application-oriented critique", + "Academic Professor who has Theoretical and pedagogical perspective", + "AI Ethicist who has Focus on societal impacts and ethical considerations", + "Data Analyst who has Data-driven and statistical approach", + "Software Developer who has Emphasis on implementation and efficiency", + "Peer Reviewer who has Critical and thorough examination", + "Conference Attendee who has General and broad overview", + "Machine Learning Enthusiast who has Curious and eager tone", + "Journal Editor who has Editorial and concise commentary", + "Graduate Research Assistant who has Detailed and diligent analysis", + "Algorithm Specialist who has Technical and focused on algorithmic aspects", + "Statistician who has Emphasis on statistical methods and validity", + "Postdoctoral Researcher who has Advanced and knowledgeable insights", + "Technology Consultant who has Business and application-oriented perspective", + "AI Policy Maker who has Concerned with regulatory and policy implications", + "Robotics Engineer who has Focus on practical applications in robotics", + "Computer Science Undergraduate who has Beginner-level understanding and curiosity", + "Innovator in Tech who has Emphasis on novelty and innovation", + "Venture Capitalist who has Interested in commercial potential and scalability", + "User Experience Designer who has Concerned with usability and design implications", + "Neuroscientist who has Interested in the intersection with cognitive science", + "Biostatistician who has Focus on applications in biological data", + "Quantitative Analyst who has Rigorous financial or econometric perspective", + "Computational Linguist who has Focus on natural language processing applications", + "AI Hobbyist who has Enthusiastic but less formal tone", + "Independent Researcher who has Unique and potentially unorthodox views", + "High School Science Teacher who has Simplified and educational perspective", + "Government Research Analyst who has Policy and societal impact focus", + "AI Safety Researcher who has Concerned with long-term impacts and safety", + "Technology Journalist who has Accessible and general-public-oriented review", + "Open Source Contributor who has Emphasis on community and collaboration", + "Clinical Researcher who has Perspective on healthcare applications", + "Entrepreneur in AI who has Business-oriented and innovation-driven", + "Legal Expert in Tech who has Focus on legal and regulatory aspects", + "Environmental Scientist who has Concern with ecological and sustainability aspects", + "Ethical Hacker who has Security-focused and critical perspective", + "Lecturer in Computer Science who has Educational and structured critique", + "Artificial Intelligence Critic who has Skeptical and challenging views", + "Tech Support Specialist who has Practical and user-focused angle", + "Multimedia Artist who has Interest in creative AI applications", + "Public Speaker in Tech who has Persuasive and impactful tone", + "International Research Collaborator who has Global and diverse perspective", + "Cybersecurity Expert who has Focus on data privacy and security issues", + "Philosopher of Science who has Conceptual and philosophical approach", + "Mathematician who has Rigorous and formal mathematical analysis", + "Science Fiction Writer who has Imaginative and speculative angle", + "Psychologist who has Interested in human-AI interaction", + "Skeptical General Public 
who has Layman's perspective and common misconceptions" + ] + } + } + ] +} \ No newline at end of file diff --git a/example/text/openreview_huggingface/variation_api_prompt.json b/example/text/openreview_huggingface/variation_api_prompt.json new file mode 100644 index 0000000..4615563 --- /dev/null +++ b/example/text/openreview_huggingface/variation_api_prompt.json @@ -0,0 +1,23 @@ +{ + "message_template": [ + { + "role": "user", + "content": "Based on Area: {area}\tRecommendation: {recommendation}, please rephrase the following sentences {tone} as a paper review:\n{sample}" + } + ], + "replacement_rules": [ + { + "constraints": {}, + "replacements": { + "tone": [ + "in a detailed way", + "in a professional way", + "with more details", + "with a professional tone", + "in a professional style", + "in a concise manner" + ] + } + } + ] +} \ No newline at end of file diff --git a/example/text/openreview_openai/main.py b/example/text/openreview_openai/main.py new file mode 100644 index 0000000..3c586fe --- /dev/null +++ b/example/text/openreview_openai/main.py @@ -0,0 +1,106 @@ +""" +This example follows the experimental settings of the GPT-3.5 OpenReview experiments in the ICML 2024 Spotlight paper, +"Differentially Private Synthetic Data via Foundation Model APIs 2: Text" (https://arxiv.org/abs/2403.01749), except +that the model is changed from GPT-3.5 to GPT-4o-mini-2024-07-18 as the original GPT-3.5 model version used in the +paper is no longer available. + +To run the code, the following environment variables are required: +* OPENAI_API_KEY: OpenAI API key. You can get it from https://platform.openai.com/account/api-keys. Multiple keys can + be separated by commas, and a key will be selected randomly for each request. + +We can also switch from OpenAI API to Azure OpenAI API by using :py:class:`pe.llm.azure_openai.AzureOpenAILLM` instead +of :py:class:`pe.llm.openai.OpenAILLM`. In that case, the following environment variables are required: +* AZURE_OPENAI_API_KEY: Azure OpenAI API key. You can get it from https://portal.azure.com/. Multiple keys can be + separated by commas, and a key will be selected randomly for each request. The key can also be "AZ_CLI", in which + case the Azure CLI will be used to authenticate the requests, and the environment variable AZURE_OPENAI_API_SCOPE + needs to be set. See Azure OpenAI authentication documentation for more information: + https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#microsoft-entra-id-authentication +* AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint. You can get it from https://portal.azure.com/. +* AZURE_OPENAI_API_VERSION: Azure OpenAI API version. You can get it from https://portal.azure.com/. + +These environment variables can be set in a .env file in the same directory as this script. For example: +``` +OPENAI_API_KEY=your_openai_api_key +``` +See https://github.com/theskumar/python-dotenv for more information about the .env file. + +For detailed information about parameters and APIs, please consult the documentation of the Private Evolution library: +https://microsoft.github.io/DPSDA/. 
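# The module docstring above notes that pe.llm.azure_openai.AzureOpenAILLM can be used in
# place of pe.llm.openai.OpenAILLM once the Azure OpenAI environment variables are set.
# A hedged sketch of that swap follows; it assumes AzureOpenAILLM accepts the same
# constructor arguments as the OpenAILLM call in this example (not confirmed by this
# change) and that the `model` value matches your Azure deployment.
from pe.llm.azure_openai import AzureOpenAILLM

llm = AzureOpenAILLM(
    max_completion_tokens=1000,
    model="gpt-4o-mini-2024-07-18",  # on Azure, this is typically the deployment name
    temperature=1.2,
    num_threads=4,
)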
+""" + +from dotenv import load_dotenv + +from pe.data.text import OpenReview +from pe.logging import setup_logging +from pe.runner import PE +from pe.population import PEPopulation +from pe.api.text import LLMAugPE +from pe.llm import OpenAILLM +from pe.embedding.text import SentenceTransformer +from pe.histogram import NearestNeighbors +from pe.callback import SaveCheckpoints +from pe.callback import ComputeFID +from pe.callback import SaveTextToCSV +from pe.logger import CSVPrint +from pe.logger import LogPrint + +import pandas as pd +import os +import numpy as np + +pd.options.mode.copy_on_write = True + + +if __name__ == "__main__": + exp_folder = "results/text/openreview_openai_api" + current_folder = os.path.dirname(os.path.abspath(__file__)) + + load_dotenv() + + setup_logging(log_file=os.path.join(exp_folder, "log.txt")) + + data = OpenReview(root_dir="/tmp/data/openreview") + llm = OpenAILLM(max_completion_tokens=1000, model="gpt-4o-mini-2024-07-18", temperature=1.2, num_threads=4) + api = LLMAugPE( + llm=llm, + random_api_prompt_file=os.path.join(current_folder, "random_api_prompt.json"), + variation_api_prompt_file=os.path.join(current_folder, "variation_api_prompt.json"), + min_word_count=25, + word_count_std=30, + token_to_word_ratio=5, + max_completion_tokens_limit=1200, + blank_probabilities=0.5, + ) + embedding = SentenceTransformer(model="stsb-roberta-base-v2") + histogram = NearestNeighbors( + embedding=embedding, + mode="L2", + lookahead_degree=0, + ) + population = PEPopulation( + api=api, initial_variation_api_fold=3, next_variation_api_fold=3, keep_selected=True, selection_mode="rank" + ) + + save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint")) + compute_fid = ComputeFID(priv_data=data, embedding=embedding) + save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text")) + + csv_print = CSVPrint(output_folder=exp_folder) + log_print = LogPrint() + + num_private_samples = len(data.data_frame) + delta = 1.0 / num_private_samples / np.log(num_private_samples) + + pe_runner = PE( + priv_data=data, + population=population, + histogram=histogram, + callbacks=[save_checkpoints, save_text_to_csv, compute_fid], + loggers=[csv_print, log_print], + ) + pe_runner.run( + num_samples_schedule=[2000] * 11, + delta=delta, + epsilon=1.0, + checkpoint_path=os.path.join(exp_folder, "checkpoint"), + ) diff --git a/example/text/openreview_openai/random_api_prompt.json b/example/text/openreview_openai/random_api_prompt.json new file mode 100644 index 0000000..dcdf1d0 --- /dev/null +++ b/example/text/openreview_openai/random_api_prompt.json @@ -0,0 +1,12 @@ +{ + "message_template": [ + { + "role": "system", + "content": "Given the area and final decision of a research paper, you are required to provide a **detailed and long** review consisting of the following content: 1. briefly summarizing the paper in 3-5 sentences; 2. listing the strengths and weaknesses of the paper in details; 3. briefly summarizing the review in 3-5 sentences." 
+ }, + { + "role": "user", + "content": "Area: {area}\tRecommendation: {recommendation}" + } + ] +} \ No newline at end of file diff --git a/example/text/openreview_openai/variation_api_prompt.json b/example/text/openreview_openai/variation_api_prompt.json new file mode 100644 index 0000000..124fc16 --- /dev/null +++ b/example/text/openreview_openai/variation_api_prompt.json @@ -0,0 +1,12 @@ +{ + "message_template": [ + { + "role": "system", + "content": "You are an AI assistant that helps people find information." + }, + { + "role": "user", + "content": "Based on the area and final decision of a research paper, you are required to fill in the blanks for the input sentences **in a concise manner**. If there is no blanks, please output the original input sentences.\nArea: Applications (eg, speech processing, computer vision, NLP)\tRecommendation: 3: reject, not good enough.\nInput: __ proposes an__ method_ ROI detection__arial_f_ without attention_. The_ map can_ used____ for__ and____ show_ improvements on different medical__._Strength__ \n--The idea using__actual images_ sali__ generation_ interesting.\n\n_The improvement____aks is significant. \n\nWeak____The___ and_____ experiments are needed_ such as__f___the_ method_ interesting_ but_ novelty_ limited\nFill-in-Blanks and your answer MUST be exactly 85 words: This paper proposes an attention generation method for ROI detection by adversarial counterfactual without attention label. The attention map can be used to highlight useful information for disease classification and detection. The experiments show its improvements on different medical imaging tasks. \nStrengths: \n--The idea using counterfactual images for saliency map generation is interesting.\n\n--The improvement for medical imaging taks is significant. \n\nWeaknesses:\n\n--The novelty is simple and limited. \n\n--More experiments are needed, such as existing counterfactual generation.\nthe proposed method is interesting, but the novelty is limited\n\n\nArea: {area}\tRecommendation: {recommendation}.\nInput: {masked_sample}\nFill-in-Blanks and your answer MUST be exactly {word_count} words:" + } + ] +} \ No newline at end of file diff --git a/example/text/pubmed_huggingface/main.py b/example/text/pubmed_huggingface/main.py new file mode 100644 index 0000000..0ec4371 --- /dev/null +++ b/example/text/pubmed_huggingface/main.py @@ -0,0 +1,81 @@ +""" +This example follows the experimental settings of the GPT-2 PubMed experiments in the ICML 2024 Spotlight paper, +"Differentially Private Synthetic Data via Foundation Model APIs 2: Text" (https://arxiv.org/abs/2403.01749). + +The ``model_name_or_path`` parameter can be set to other models on HuggingFace. Note that we use the FastChat +library (https://github.com/lm-sys/FastChat) to manage the conversation template. If the conversation template of your +desired model is not available in FastChat, please register the conversation template in the FastChat library. See the +following link for an example: +https://github.com/microsoft/DPSDA/blob/main/pe/llm/huggingface/register_fastchat/gpt2.py + +For detailed information about parameters and APIs, please consult the documentation of the Private Evolution library: +https://microsoft.github.io/DPSDA/. 
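# Every example script in this change sets the DP parameter delta from the private dataset
# size as delta = 1 / (n * ln(n)). A quick numeric check with a hypothetical dataset size
# (the value of n below is made up and is not the actual PubMed size):
import numpy as np

n = 75_316                       # hypothetical number of private samples
delta = 1.0 / n / np.log(n)      # same expression as in the scripts below
print(f"delta = {delta:.2e}")    # roughly 1.2e-06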
+""" + +from pe.data.text import PubMed +from pe.logging import setup_logging +from pe.runner import PE +from pe.population import PEPopulation +from pe.api.text import LLMAugPE +from pe.llm import HuggingfaceLLM +from pe.embedding.text import SentenceTransformer +from pe.histogram import NearestNeighbors +from pe.callback import SaveCheckpoints +from pe.callback import ComputeFID +from pe.callback import SaveTextToCSV +from pe.logger import CSVPrint +from pe.logger import LogPrint + +import pandas as pd +import os +import numpy as np + +pd.options.mode.copy_on_write = True + + +if __name__ == "__main__": + exp_folder = "results/text/pubmed_huggingface" + current_folder = os.path.dirname(os.path.abspath(__file__)) + + setup_logging(log_file=os.path.join(exp_folder, "log.txt")) + + data = PubMed(root_dir="/tmp/data/pubmed") + llm = HuggingfaceLLM(max_completion_tokens=448, model_name_or_path="gpt2", temperature=1.0) + api = LLMAugPE( + llm=llm, + random_api_prompt_file=os.path.join(current_folder, "random_api_prompt.json"), + variation_api_prompt_file=os.path.join(current_folder, "variation_api_prompt.json"), + ) + embedding = SentenceTransformer(model="sentence-t5-base") + histogram = NearestNeighbors( + embedding=embedding, + mode="L2", + lookahead_degree=0, + ) + population = PEPopulation( + api=api, initial_variation_api_fold=6, next_variation_api_fold=6, keep_selected=True, selection_mode="rank" + ) + + save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint")) + compute_fid = ComputeFID(priv_data=data, embedding=embedding) + save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text")) + + csv_print = CSVPrint(output_folder=exp_folder) + log_print = LogPrint() + + num_private_samples = len(data.data_frame) + delta = 1.0 / num_private_samples / np.log(num_private_samples) + + pe_runner = PE( + priv_data=data, + population=population, + histogram=histogram, + callbacks=[save_checkpoints, save_text_to_csv, compute_fid], + loggers=[csv_print, log_print], + ) + pe_runner.run( + num_samples_schedule=[2000] * 11, + delta=delta, + epsilon=1.0, + checkpoint_path=os.path.join(exp_folder, "checkpoint"), + ) diff --git a/example/text/pubmed_huggingface/random_api_prompt.json b/example/text/pubmed_huggingface/random_api_prompt.json new file mode 100644 index 0000000..adecee6 --- /dev/null +++ b/example/text/pubmed_huggingface/random_api_prompt.json @@ -0,0 +1,8 @@ +{ + "message_template": [ + { + "role": "user", + "content": "Using a variety of sentence structures, write an abstract for a medical research paper:" + } + ] +} \ No newline at end of file diff --git a/example/text/pubmed_huggingface/variation_api_prompt.json b/example/text/pubmed_huggingface/variation_api_prompt.json new file mode 100644 index 0000000..140468b --- /dev/null +++ b/example/text/pubmed_huggingface/variation_api_prompt.json @@ -0,0 +1,26 @@ +{ + "message_template": [ + { + "role": "user", + "content": "Please rephrase the following sentences {tone} as an abstract for medical research paper:\n{sample}" + } + ], + "replacement_rules": [ + { + "constraints": {}, + "replacements": { + "tone": [ + "in a professional way", + "in a professional tone", + "in a professional style", + "in a concise manner", + "in a creative style", + "using imagination", + "in a storytelling tone", + "in a formal manner", + "using a variety of sentence structures" + ] + } + } + ] +} \ No newline at end of file diff --git a/example/text/pubmed_openai/main.py b/example/text/pubmed_openai/main.py new file 
mode 100644 index 0000000..0c37e03 --- /dev/null +++ b/example/text/pubmed_openai/main.py @@ -0,0 +1,106 @@ +""" +This example follows the experimental settings of the GPT-3.5 PubMed experiments in the ICML 2024 Spotlight paper, +"Differentially Private Synthetic Data via Foundation Model APIs 2: Text" (https://arxiv.org/abs/2403.01749), except +that the model is changed from GPT-3.5 to GPT-4o-mini-2024-07-18 as the original GPT-3.5 model version used in the +paper is no longer available. + +To run the code, the following environment variables are required: +* OPENAI_API_KEY: OpenAI API key. You can get it from https://platform.openai.com/account/api-keys. Multiple keys can + be separated by commas, and a key will be selected randomly for each request. + +We can also switch from OpenAI API to Azure OpenAI API by using :py:class:`pe.llm.azure_openai.AzureOpenAILLM` instead +of :py:class:`pe.llm.openai.OpenAILLM`. In that case, the following environment variables are required: +* AZURE_OPENAI_API_KEY: Azure OpenAI API key. You can get it from https://portal.azure.com/. Multiple keys can be + separated by commas, and a key will be selected randomly for each request. The key can also be "AZ_CLI", in which + case the Azure CLI will be used to authenticate the requests, and the environment variable AZURE_OPENAI_API_SCOPE + needs to be set. See Azure OpenAI authentication documentation for more information: + https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#microsoft-entra-id-authentication +* AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint. You can get it from https://portal.azure.com/. +* AZURE_OPENAI_API_VERSION: Azure OpenAI API version. You can get it from https://portal.azure.com/. + +These environment variables can be set in a .env file in the same directory as this script. For example: +``` +OPENAI_API_KEY=your_openai_api_key +``` +See https://github.com/theskumar/python-dotenv for more information about the .env file. + +For detailed information about parameters and APIs, please consult the documentation of the Private Evolution library: +https://microsoft.github.io/DPSDA/. 
+""" + +from dotenv import load_dotenv + +from pe.data.text import PubMed +from pe.logging import setup_logging +from pe.runner import PE +from pe.population import PEPopulation +from pe.api.text import LLMAugPE +from pe.llm import OpenAILLM +from pe.embedding.text import SentenceTransformer +from pe.histogram import NearestNeighbors +from pe.callback import SaveCheckpoints +from pe.callback import ComputeFID +from pe.callback import SaveTextToCSV +from pe.logger import CSVPrint +from pe.logger import LogPrint + +import pandas as pd +import os +import numpy as np + +pd.options.mode.copy_on_write = True + + +if __name__ == "__main__": + exp_folder = "results/text/pubmed_openai_api" + current_folder = os.path.dirname(os.path.abspath(__file__)) + + load_dotenv() + + setup_logging(log_file=os.path.join(exp_folder, "log.txt")) + + data = PubMed(root_dir="/tmp/data/pubmed") + llm = OpenAILLM(max_completion_tokens=1000, model="gpt-4o-mini-2024-07-18", temperature=1.2, num_threads=4) + api = LLMAugPE( + llm=llm, + random_api_prompt_file=os.path.join(current_folder, "random_api_prompt.json"), + variation_api_prompt_file=os.path.join(current_folder, "variation_api_prompt.json"), + min_word_count=25, + word_count_std=36, + token_to_word_ratio=5, + max_completion_tokens_limit=1200, + blank_probabilities=0.6, + ) + embedding = SentenceTransformer(model="sentence-t5-base") + histogram = NearestNeighbors( + embedding=embedding, + mode="L2", + lookahead_degree=0, + ) + population = PEPopulation( + api=api, initial_variation_api_fold=3, next_variation_api_fold=3, keep_selected=True, selection_mode="rank" + ) + + save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint")) + compute_fid = ComputeFID(priv_data=data, embedding=embedding) + save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text")) + + csv_print = CSVPrint(output_folder=exp_folder) + log_print = LogPrint() + + num_private_samples = len(data.data_frame) + delta = 1.0 / num_private_samples / np.log(num_private_samples) + + pe_runner = PE( + priv_data=data, + population=population, + histogram=histogram, + callbacks=[save_checkpoints, save_text_to_csv, compute_fid], + loggers=[csv_print, log_print], + ) + pe_runner.run( + num_samples_schedule=[2000] * 11, + delta=delta, + epsilon=1.0, + checkpoint_path=os.path.join(exp_folder, "checkpoint"), + ) diff --git a/example/text/pubmed_openai/random_api_prompt.json b/example/text/pubmed_openai/random_api_prompt.json new file mode 100644 index 0000000..34657eb --- /dev/null +++ b/example/text/pubmed_openai/random_api_prompt.json @@ -0,0 +1,71 @@ +{ + "message_template": [ + { + "role": "system", + "content": "Please act as a sentence generator for the medical domain. Generated sentences should mimic the style of PubMed journal articles, using a variety of sentence structures." + }, + { + "role": "user", + "content": "Suppose that you are a {writer}. 
Please provide an example of an abstract for a medical research paper:" + } + ], + "replacement_rules": [ + { + "constraints": {}, + "replacements": { + "writer": [ + "Clinical Researcher", + "Principal Investigator", + "Biomedical Engineer", + "Psychologist", + "Endocrinologist", + "Oncologist", + "Neurologist", + "Epidemiologist", + "Biostatistician", + "Medical Reviewer", + "Laboratory Technician", + "Pharmaceutics Expert", + "Geneticist", + "Immunologist", + "Public Health Official", + "Clinical Trial Coordinator", + "Medical Ethicist", + "Healthcare Policy Maker", + "Medical Journal Editor", + "Peer Reviewer", + "Pharmacologist", + "Graduate Medical Student", + "Postdoctoral Researcher", + "Medical Librarian", + "Nursing Researcher", + "Physiologist", + "Toxicologist", + "Pathologist", + "Radiologist", + "Surgical Researcher", + "Medical Historian", + "Biogerontologist", + "Bioinformatician", + "Gynecologist", + "Pediatrician", + "Dermatologist", + "Orthopedic Surgeon", + "Anesthesiologist", + "Cardiologist", + "Virologist", + "Molecular Biologist", + "Nutritionist", + "Sports Medicine Specialist", + "Rehabilitation Specialist", + "Health Economist", + "Patient Advocate", + "Bioethics Committee", + "Clinical Data Manager", + "Medical Statistician", + "Health Technology Assessor" + ] + } + } + ] +} \ No newline at end of file diff --git a/example/text/pubmed_openai/variation_api_prompt.json b/example/text/pubmed_openai/variation_api_prompt.json new file mode 100644 index 0000000..101fea1 --- /dev/null +++ b/example/text/pubmed_openai/variation_api_prompt.json @@ -0,0 +1,12 @@ +{ + "message_template": [ + { + "role": "system", + "content": "Please act as a sentence generator for the medical domain. Generated sentences should mimic the style of PubMed journal articles, using a variety of sentence structures." + }, + { + "role": "user", + "content": "You are required to fill in the blanks with more details for the input medical abstract in a professional tone. If there is no blanks, please output the original medical abstract.\nPlease fill in the blanks in the following sentences to write an abstract of a medical research paper: \"{masked_sample}\" and your answer MUST be exactly {word_count} words.\n" + } + ] +} \ No newline at end of file diff --git a/example/text/yelp_huggingface/main.py b/example/text/yelp_huggingface/main.py new file mode 100644 index 0000000..d90e546 --- /dev/null +++ b/example/text/yelp_huggingface/main.py @@ -0,0 +1,81 @@ +""" +This example follows the experimental settings of the GPT-2 Yelp experiments in the ICML 2024 Spotlight paper, +"Differentially Private Synthetic Data via Foundation Model APIs 2: Text" (https://arxiv.org/abs/2403.01749). + +The ``model_name_or_path`` parameter can be set to other models on HuggingFace. Note that we use the FastChat +library (https://github.com/lm-sys/FastChat) to manage the conversation template. If the conversation template of your +desired model is not available in FastChat, please register the conversation template in the FastChat library. See the +following link for an example: +https://github.com/microsoft/DPSDA/blob/main/pe/llm/huggingface/register_fastchat/gpt2.py + +For detailed information about parameters and APIs, please consult the documentation of the Private Evolution library: +https://microsoft.github.io/DPSDA/. 
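# The RANDOM_API and VARIATION_API prompts used by LLMAugPE are defined in JSON files such
# as the random_api_prompt.json below: a "message_template" whose contents contain
# {placeholders}, plus optional "replacement_rules" mapping a set of constraints (e.g., a
# business category) to candidate values for a placeholder (e.g., a keyword list). The
# sketch below only illustrates how such a file could be rendered into chat messages; it is
# not how LLMAugPE is implemented internally, and the attribute values in the usage comment
# are hypothetical.
import json
import random


def render_prompt(prompt_file, attributes):
    """Fill a prompt template with attribute values and randomly chosen replacements."""
    with open(prompt_file) as f:
        spec = json.load(f)
    values = dict(attributes)
    for rule in spec.get("replacement_rules", []):
        # A rule is assumed to apply when all of its constraints match the attributes.
        if all(values.get(key) == value for key, value in rule["constraints"].items()):
            for placeholder, candidates in rule["replacements"].items():
                values[placeholder] = random.choice(candidates)
    return [
        {"role": message["role"], "content": message["content"].format(**values)}
        for message in spec["message_template"]
    ]


# Example usage (hypothetical attribute values):
# render_prompt("random_api_prompt.json", {"business_category": "Bars", "review_stars": "5.0"})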
+""" + +from pe.data.text import Yelp +from pe.logging import setup_logging +from pe.runner import PE +from pe.population import PEPopulation +from pe.api.text import LLMAugPE +from pe.llm import HuggingfaceLLM +from pe.embedding.text import SentenceTransformer +from pe.histogram import NearestNeighbors +from pe.callback import SaveCheckpoints +from pe.callback import ComputeFID +from pe.callback import SaveTextToCSV +from pe.logger import CSVPrint +from pe.logger import LogPrint + +import pandas as pd +import os +import numpy as np + +pd.options.mode.copy_on_write = True + + +if __name__ == "__main__": + exp_folder = "results/text/yelp_huggingface" + current_folder = os.path.dirname(os.path.abspath(__file__)) + + setup_logging(log_file=os.path.join(exp_folder, "log.txt")) + + data = Yelp(root_dir="/tmp/data/yelp") + llm = HuggingfaceLLM(max_completion_tokens=64, model_name_or_path="gpt2", temperature=1.4) + api = LLMAugPE( + llm=llm, + random_api_prompt_file=os.path.join(current_folder, "random_api_prompt.json"), + variation_api_prompt_file=os.path.join(current_folder, "variation_api_prompt.json"), + ) + embedding = SentenceTransformer(model="stsb-roberta-base-v2") + histogram = NearestNeighbors( + embedding=embedding, + mode="L2", + lookahead_degree=0, + ) + population = PEPopulation( + api=api, initial_variation_api_fold=6, next_variation_api_fold=6, keep_selected=True, selection_mode="rank" + ) + + save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint")) + compute_fid = ComputeFID(priv_data=data, embedding=embedding) + save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text")) + + csv_print = CSVPrint(output_folder=exp_folder) + log_print = LogPrint() + + num_private_samples = len(data.data_frame) + delta = 1.0 / num_private_samples / np.log(num_private_samples) + + pe_runner = PE( + priv_data=data, + population=population, + histogram=histogram, + callbacks=[save_checkpoints, save_text_to_csv, compute_fid], + loggers=[csv_print, log_print], + ) + pe_runner.run( + num_samples_schedule=[5000] * 21, + delta=delta, + epsilon=1.0, + checkpoint_path=os.path.join(exp_folder, "checkpoint"), + ) diff --git a/example/text/yelp_huggingface/random_api_prompt.json b/example/text/yelp_huggingface/random_api_prompt.json new file mode 100644 index 0000000..a057199 --- /dev/null +++ b/example/text/yelp_huggingface/random_api_prompt.json @@ -0,0 +1,1100 @@ +{ + "message_template": [ + { + "role": "user", + "content": "Business Category: {business_category}\tReview Stars: {review_stars} with keyword {keyword}" + } + ], + "replacement_rules": [ + { + "constraints": { + "business_category": "Arts & Entertainment" + }, + "replacements": { + "keyword": [ + "Art Galleries", + "Museums", + "Live Music Venues", + "Theaters", + "Dance Studios", + "Comedy Clubs", + "Film Festivals", + "Performing Arts Centers", + "Concert Halls", + "Jazz Clubs", + "Opera Houses", + "Symphony Orchestras", + "Ballet Companies", + "Art Exhibitions", + "Street Performances", + "Improv Shows", + "Stand-Up Comedy", + "Music Festivals", + "Film Screening Events", + "Art Workshops", + "Art Classes", + "Art Installations", + "Art Fairs", + "Sculpture Gardens", + "Public Art Displays", + "Art Auctions", + "Art Museums", + "Contemporary Art Spaces", + "Ceramic Studios", + "Pottery Classes", + "Photography Exhibitions", + "Street Art Tours", + "Graffiti Art", + "Poetry Readings", + "Literary Festivals", + "Bookstores", + "Storytelling Events", + "Cabaret Shows", + "Magic Shows", + "Circus 
Performances", + "Puppet Shows", + "Fashion Shows", + "Body Painting", + "Burlesque Shows", + "Drag Performances", + "Spoken Word Performances", + "Opera Performances", + "Outdoor Concerts", + "Tribute Bands", + "Music Open Mic Nights", + "Indie Music Venues", + "Jazz Festivals", + "Theatre Festivals", + "Dance Performances", + "Dance Workshops", + "Dance Competitions", + "Film Premiers", + "Film Screenwriting Workshops", + "Film Production Studios", + "Animation Studios", + "Film Awards Ceremonies", + "Film Retrospectives", + "Film Noir Screenings", + "Film Documentaries", + "Independent Film Showcases", + "Film Director Q&A Sessions", + "Art History Lectures", + "Art Tours", + "Art Therapy Workshops", + "Art Supply Stores", + "Art Conservation Services", + "Public Murals", + "Live Street Art Performances", + "Art Film Screenings", + "Art Restoration Services", + "Art Book Signings", + "Art Magazine Launches", + "Artisan Markets", + "Artisan Workshops", + "DIY Craft Events", + "Virtual Reality Experiences", + "Gaming Conventions", + "Esports Tournaments", + "Cosplay Events", + "Comic Book Conventions", + "Anime Festivals", + "Pop Culture Expos", + "Board Game Cafés", + "Trivia Nights", + "Karaoke Bars", + "Outdoor Theater Performances", + "Artisan Food and Beverage Tastings", + "Historic Site Tours", + "Music Education Programs", + "Film Soundtrack Concerts", + "Art Performance Installations", + "Artisanal Food Markets", + "Circus Training Workshops", + "Live Painting Demonstrations", + "Art Film Festivals" + ] + } + }, + { + "constraints": { + "business_category": "Bars" + }, + "replacements": { + "keyword": [ + "Sports Bars", + "Dive Bars", + "Cocktail Bars", + "Brewpubs", + "Wine Bars", + "Karaoke Bars", + "Jazz Bars", + "Tiki Bars", + "Rooftop Bars", + "Irish Pubs", + "Whiskey Bars", + "Beer Gardens", + "Speakeasies", + "Neighborhood Bars", + "Gay Bars", + "Salsa Bars", + "Cigar Bars", + "Piano Bars", + "Country Bars", + "College Bars", + "Hotel Bars", + "Live Music Bars", + "Beach Bars", + "Craft Beer Bars", + "Comedy Clubs with Bars", + "Tequila Bars", + "Rum Bars", + "Gin Bars", + "Martini Bars", + "Bourbon Bars", + "Scotch Bars", + "Blues Bars", + "Reggae Bars", + "Wine Tasting Bars", + "Distillery Bars", + "Outdoor Bars", + "Latin Bars", + "Upscale Bars", + "Lounge Bars", + "Artisanal Cocktail Bars", + "Whiskey Tasting Bars", + "Cider Bars", + "Hipster Bars", + "Underground Bars", + "Cabaret Bars", + "Burlesque Bars", + "Board Game Bars", + "Arcade Bars", + "Craft Cocktail Bars", + "Sake Bars", + "Biker Bars", + "Tapas Bars", + "Microbrewery Taprooms", + "Speakeasy-Style Bars", + "Absinthe Bars", + "Vodka Bars", + "Beachfront Bars", + "80s Bars", + "90s Bars", + "Swanky Bars", + "Rum Tasting Bars", + "Irish Whiskey Bars", + "Sours Bars", + "Whiskey and Cigar Bars", + "Themed Bars", + "Mezcal Bars", + "Hawaiian Tiki Bars", + "German Beer Halls", + "Rooftop Sky Bars", + "Rustic Bars", + "Wine Cellar Bars", + "Gin and Tonic Bars", + "Underground Speakeasies", + "Jazz Speakeasies", + "Secret Bars", + "Piano Karaoke Bars", + "Reggaeton Bars", + "Bachelorette Party Bars", + "Outdoor Rooftop Bars", + "Blues and BBQ Bars", + "Rum Distillery Bars", + "Cocktail Mixology Bars", + "Margarita Bars", + "Classic Cocktail Bars", + "Nightclub Bars", + "Whiskey and BBQ Bars", + "Cabana Bars", + "Rooftop Pool Bars", + "Drag Bars", + "Wine and Cheese Bars", + "Prohibition-Style Bars", + "Tropical Bars", + "Latin Dance Bars", + "Rum Tiki Bars", + "Rooftop Lounge Bars", + "Beer Flight Bars", + 
"Barrel-Aged Beer Bars", + "Tropical Cocktail Bars", + "Jazz and Blues Bars", + "Outdoor Beach Bars" + ] + } + }, + { + "constraints": { + "business_category": "Beauty & Spas" + }, + "replacements": { + "keyword": [ + "Hair Salons", + "Nail Salons", + "Day Spas", + "Massage Therapy", + "Facial Services", + "Waxing Studios", + "Eyebrow Threading", + "Makeup Artists", + "Barber Shops", + "Hair Removal Services", + "Tanning Salons", + "Body Treatments", + "Manicure and Pedicure Services", + "Spa Packages", + "Skin Care Clinics", + "Medical Spas", + "Acupuncture Services", + "Ayurvedic Treatments", + "Aromatherapy", + "Reflexology Services", + "Reiki Healing", + "Holistic Wellness Centers", + "Hair Color Services", + "Hair Extensions", + "Bridal Hair and Makeup Services", + "Lash Extensions", + "Microblading Services", + "Tattoo Studios", + "Permanent Makeup", + "Blowout Services", + "Scalp Treatments", + "Brazilian Waxing", + "Eyelash Lift and Tint", + "Nail Art Studios", + "Gel Nail Services", + "Hair Braiding", + "Spa Facials", + "Hot Stone Massages", + "Deep Tissue Massages", + "Swedish Massages", + "Sports Massages", + "Thai Massages", + "Couples Massages", + "Body Scrubs", + "Body Wraps", + "Sauna Services", + "Floatation Therapy", + "Cupping Therapy", + "Oxygen Bar Services", + "Spa Manicures and Pedicures", + "Shellac Nail Services", + "Gel Polish Removal", + "Foot Massage Services", + "Back Massage Services", + "Anti-Aging Treatments", + "Chemical Peels", + "Microdermabrasion", + "Laser Hair Removal", + "Botox Injections", + "Dermal Fillers", + "Lip Enhancement", + "Facial Rejuvenation", + "Coolsculpting Services", + "Fat Reduction Treatments", + "Body Contouring", + "Cryotherapy Services", + "Spa Parties", + "Mobile Beauty Services", + "Wellness Retreats", + "Meditation Classes", + "Yoga Studios", + "Pilates Studios", + "Fitness Centers with Spa Services", + "Infrared Sauna Therapy", + "Body Piercing Studios", + "Non-Surgical Facelifts", + "Vampire Facials", + "Microneedling Services", + "Teeth Whitening Services", + "Hair Straightening Services", + "Balayage Services", + "Highlights and Lowlights", + "Keratin Treatments", + "Scalp Micropigmentation", + "Henna Tattoos", + "Laser Skin Resurfacing", + "Body Piercing Jewelry and Accessories", + "Spa Membership Programs", + "Esthetician Services", + "Reflexology Foot Spas", + "Hand and Arm Massages", + "Deep Cleansing Facials", + "Couples Spa Packages", + "Lymphatic Drainage Massages", + "Eyebrow and Eyelash Tinting", + "Beard Grooming Services", + "Spa Consultations", + "Wellness Coaching", + "Hair Loss Treatments", + "Herbal Wraps" + ] + } + }, + { + "constraints": { + "business_category": "Event Planning & Services" + }, + "replacements": { + "keyword": [ + "Wedding Planning", + "Corporate Event Planning", + "Party Planning", + "Event Decorations", + "Event Rentals", + "Event Photography", + "Event Videography", + "Event Lighting Services", + "Event DJ Services", + "Event Catering Services", + "Event Staffing", + "Event Security Services", + "Event Ticketing Services", + "Event Transportation Services", + "Event Marketing and Promotion", + "Event Audiovisual Services", + "Event Technology Solutions", + "Event Production Services", + "Event Venue Selection", + "Event Registration Services", + "Event Signage and Branding", + "Event Graphic Design", + "Event Floral Design", + "Event Entertainment", + "Event Emcees and Hosts", + "Event Planning Consultation", + "Event Logistics Management", + "Event Budgeting and Financial Planning", + 
"Event Stage Design and Setup", + "Event Theme Development", + "Event Auction Services", + "Event Sponsorship Management", + "Event Public Relations", + "Event Social Media Management", + "Event Website Design and Development", + "Event Crowd Management", + "Event Exhibitor Services", + "Event Printing and Collateral Services", + "Event Risk Assessment and Management", + "Event Equipment Rentals", + "Event Health and Safety Services", + "Event Waste Management", + "Event First Aid Services", + "Event Interpretation and Translation Services", + "Event Virtual and Hybrid Solutions", + "Event Drone Photography and Videography", + "Event Projection Mapping", + "Event Fireworks and Pyrotechnics", + "Event Tent and Canopy Rentals", + "Event Flooring and Staging", + "Event Valet Parking Services", + "Event Marketing Collateral Design", + "Event Theme Party Planning", + "Event Destination Management", + "Event Team Building Activities", + "Event Press and Media Coverage", + "Event Celebrity Booking", + "Event Fashion Show Production", + "Event Awards and Recognition Programs", + "Event Live Streaming Services", + "Event Sound System Rentals", + "Event Trade Show Booth Design", + "Event Product Launch Planning", + "Event Fashion Styling and Consulting", + "Event Charity Auctions", + "Event Fashion Runway Design", + "Event Event Swag and Merchandise", + "Event Costume and Prop Rentals", + "Event Concert Production", + "Event Venue Coordination", + "Event Invitation and Stationery Design", + "Event Wine and Beverage Services", + "Event Social Event Planning", + "Event Silent Auctions", + "Event Celebrity Meet and Greet", + "Event Food and Beverage Pairing", + "Event Bar and Bartending Services", + "Event Destination Weddings", + "Event Fundraising and Development", + "Event Gala Dinners", + "Event Event Branding and Identity", + "Event Conference Planning", + "Event Team Registration and Management", + "Event Incentive Travel Planning", + "Event Event Website and App Development", + "Event Theme Park and Attraction Planning", + "Event Product Demonstrations", + "Event Inflatable Rentals", + "Event Themed Entertainment", + "Event Drone Shows", + "Event Costume Design and Creation", + "Event Run/Walk/Ride Planning", + "Event Wine Tasting and Pairing", + "Event Event App Development", + "Event Pop-Up Shop Planning", + "Event Street Marketing and Promotions", + "Event Audio and Visual Equipment Sales", + "Event VIP Experiences", + "Event Mobile App Development", + "Event Gaming and Esports Experiences" + ] + } + }, + { + "constraints": { + "business_category": "Grocery" + }, + "replacements": { + "keyword": [ + "Fresh Fruits", + "Fresh Vegetables", + "Organic Produce", + "Herbs and Spices", + "Dairy Products", + "Eggs", + "Butter and Margarine", + "Milk", + "Yogurt", + "Cheese", + "Deli Meats", + "Fresh Bakery Products", + "Breads", + "Rolls and Bagels", + "Cakes and Pastries", + "Gluten-Free Products", + "Frozen Foods", + "Frozen Vegetables", + "Frozen Fruits", + "Frozen Meals", + "Ice Cream and Frozen Desserts", + "Canned Goods", + "Canned Fruits", + "Canned Vegetables", + "Canned Soups", + "Canned Beans", + "Canned Fish and Seafood", + "Condiments and Sauces", + "Ketchup and Mustard", + "Mayonnaise and Salad Dressings", + "Cooking Oils", + "Vinegars", + "Pasta and Noodles", + "Rice", + "Grains and Lentils", + "Breakfast Cereals", + "Oatmeal and Porridge", + "Pancake and Waffle Mixes", + "Snack Foods", + "Chips and Pretzels", + "Crackers", + "Popcorn", + "Nuts and Seeds", + "Dried Fruits", + "Chocolate 
and Candy", + "Beverages", + "Soft Drinks", + "Juices", + "Coffee", + "Tea", + "Bottled Water", + "Energy Drinks", + "Alcoholic Beverages", + "Beer", + "Wine", + "Spirits and Liquors", + "Baby Products", + "Baby Food", + "Diapers", + "Baby Formula", + "Personal Care Products", + "Shampoo and Conditioner", + "Soap and Body Wash", + "Toothpaste and Toothbrushes", + "Feminine Hygiene Products", + "Deodorant", + "Laundry Supplies", + "Laundry Detergent", + "Fabric Softener", + "Stain Removers", + "Paper Products", + "Toilet Paper", + "Paper Towels", + "Napkins", + "Cleaning Supplies", + "All-Purpose Cleaners", + "Dishwashing Liquid", + "Surface Disinfectants", + "Trash Bags", + "Pet Food", + "Cat Food", + "Dog Food", + "Pet Treats", + "Pet Supplies", + "Health and Wellness Products", + "Vitamins and Supplements", + "Over-the-Counter Medications", + "First Aid Supplies", + "Oral Care Products", + "Household Essentials", + "Light Bulbs", + "Batteries", + "Gardening Supplies", + "Cooking Ingredients", + "Baking Supplies", + "Sugar and Sweeteners", + "Flour", + "Yeast", + "Spreads and Jams", + "Ethnic and International Foods" + ] + } + }, + { + "constraints": { + "business_category": "Health & Medical" + }, + "replacements": { + "keyword": [ + "Family Medicine", + "Internal Medicine", + "Pediatrics", + "Obstetrics and Gynecology", + "Cardiology", + "Dermatology", + "Endocrinology", + "Gastroenterology", + "Neurology", + "Oncology", + "Ophthalmology", + "Orthopedics", + "Otolaryngology (ENT)", + "Psychiatry", + "Pulmonology", + "Rheumatology", + "Urology", + "Allergy and Immunology", + "Anesthesiology", + "Emergency Medicine", + "Radiology", + "Pathology", + "Physical Therapy", + "Occupational Therapy", + "Speech Therapy", + "Chiropractic Care", + "Acupuncture", + "Homeopathy", + "Naturopathy", + "Nutrition and Dietetics", + "Mental Health Counseling", + "Rehabilitation Services", + "Sleep Medicine", + "Pain Management", + "Geriatric Medicine", + "Sports Medicine", + "Plastic Surgery", + "Ophthalmic Surgery", + "Cardiothoracic Surgery", + "Oral and Maxillofacial Surgery", + "Podiatry", + "Audiology", + "Optometry", + "Dentistry", + "Orthodontics", + "Periodontics", + "Prosthodontics", + "Oral Surgery", + "Medical Imaging Services", + "Laboratory Services", + "Pharmaceutical Services", + "Rehabilitation Centers", + "Assisted Living Facilities", + "Nursing Homes", + "Home Health Care Services", + "Hospice Care", + "Mental Health Hospitals", + "Cancer Centers", + "Pain Clinics", + "Rehabilitation Hospitals", + "Diagnostic Centers", + "Ambulatory Surgery Centers", + "Urgent Care Centers", + "Walk-in Clinics", + "Wellness Centers", + "Weight Loss Centers", + "Rehabilitation Centers", + "Physical Fitness Centers", + "Massage Therapy", + "Yoga and Pilates Studios", + "Health Coaching", + "Health Education Programs", + "Chronic Disease Management", + "Diabetes Care and Education", + "Asthma and Allergy Management", + "HIV/AIDS Care and Support", + "Women's Health Services", + "Men's Health Services", + "Senior Health Services", + "Pediatric Care Services", + "Sexual and Reproductive Health Services", + "Preventive Care Services", + "Vaccination Clinics", + "Pain Clinics", + "Smoking Cessation Programs", + "Weight Management Programs", + "Rehabilitation Services for Injuries", + "Genetic Counseling", + "Telemedicine Services", + "Integrative Medicine", + "Cancer Support Services", + "Stroke Rehabilitation Services", + "Dementia Care Services", + "Chronic Pain Management", + "Diabetes Education and 
Management", + "Physical and Occupational Rehabilitation", + "Mental Health Support Groups", + "Orthopedic Rehabilitation Services", + "Addiction Treatment Centers", + "Holistic Healing Centers" + ] + } + }, + { + "constraints": { + "business_category": "Home & Garden" + }, + "replacements": { + "keyword": [ + "Furniture", + "Home Decor", + "Lighting Fixtures", + "Kitchen Appliances", + "Bathroom Fixtures", + "Bedroom Furniture", + "Living Room Furniture", + "Dining Room Furniture", + "Home Office Furniture", + "Outdoor Furniture", + "Home Organization", + "Storage Solutions", + "Wall Art and Paintings", + "Rugs and Carpets", + "Curtains and Window Treatments", + "Throw Pillows and Cushions", + "Tableware and Dinnerware", + "Cookware and Bakeware", + "Small Kitchen Appliances", + "Home Cleaning Products", + "Laundry Supplies", + "Home Fragrances", + "Indoor Plants and Pots", + "Garden Tools", + "Outdoor Grills and Cooking Equipment", + "Patio Furniture", + "Outdoor Lighting", + "Lawn and Garden Care", + "Gardening Accessories", + "Fertilizers and Soil Amendments", + "Outdoor Decorations", + "Planters and Raised Beds", + "Swimming Pool Supplies", + "Hot Tubs and Spas", + "Outdoor Power Equipment", + "Pest Control Products", + "Home Security Systems", + "Smart Home Devices", + "Home Improvement Tools", + "Power Tools", + "Hand Tools", + "Plumbing Supplies", + "Electrical Supplies", + "Paint and Painting Supplies", + "Flooring Materials", + "Tiles and Mosaics", + "Wallpaper and Wall Coverings", + "Home Renovation Services", + "Kitchen Renovation Services", + "Bathroom Renovation Services", + "Flooring Installation Services", + "Interior Design Services", + "Landscape Design Services", + "Home Theater Systems", + "Home Office Equipment", + "Home Gym Equipment", + "Outdoor Playsets and Swing Sets", + "Fireplaces and Wood Stoves", + "Home Energy Efficiency Products", + "Solar Panels and Systems", + "Roofing Materials", + "Doors and Windows", + "Garage Doors and Openers", + "Home Insulation", + "HVAC Systems", + "Ceiling Fans and Ventilation", + "Home Water Filtration Systems", + "Home Security Cameras", + "Lawn Mowers and Tractors", + "Gardening Books and Magazines", + "Bird Feeders and Houses", + "Composting Supplies", + "Garden Irrigation Systems", + "Outdoor Storage Sheds", + "Deck and Patio Materials", + "Outdoor Structures (Gazebos, Pergolas)", + "Lawn and Garden Decor", + "BBQ and Grill Accessories", + "Home Maintenance and Repair Services", + "Carpet Cleaning Services", + "Furniture Restoration Services", + "Home Painting Services", + "Plumbing Services", + "Electrical Services", + "Roofing Services", + "Home Cleaning Services", + "Home Staging Services", + "Pest Control Services", + "Landscaping Services", + "Tree Care and Arborist Services", + "Home Inspection Services", + "Interior Decorating Services", + "Garage Organization Systems", + "Closet and Storage Solutions", + "Home Audio Systems", + "Home Automation Systems", + "Greenhouse Supplies", + "Aquatic Plants and Pond Supplies", + "Outdoor Sound Systems", + "Holiday Decorations" + ] + } + }, + { + "constraints": { + "business_category": "Hotels & Travel" + }, + "replacements": { + "keyword": [ + "Luxury Hotels", + "Boutique Hotels", + "Budget Hotels", + "Resort Hotels", + "Bed and Breakfasts", + "Hostels", + "Vacation Rentals", + "Beachfront Hotels", + "Mountain Resorts", + "Spa Resorts", + "Casino Hotels", + "Pet-Friendly Hotels", + "Family-Friendly Hotels", + "Eco-Friendly Hotels", + "All-Inclusive Resorts", + "Business Hotels", 
+ "Airport Hotels", + "Historic Hotels", + "Honeymoon Resorts", + "Ski Resorts", + "Golf Resorts", + "Safari Lodges", + "Wellness Retreats", + "Glamping Sites", + "RV Parks and Campgrounds", + "Cruise Ships", + "Tourist Attractions", + "Sightseeing Tours", + "Adventure Travel", + "Cultural Tours", + "Food and Wine Tours", + "Outdoor Activities", + "Beach Vacations", + "City Breaks", + "Mountain Escapes", + "Hiking and Trekking Trips", + "Wildlife Safaris", + "Scuba Diving Excursions", + "Snorkeling Tours", + "Whale Watching Trips", + "Surfing Camps", + "Yoga Retreats", + "Wellness and Spa Getaways", + "Wine Country Tours", + "Culinary Vacations", + "Historical Landmarks", + "Theme Parks", + "National Parks", + "UNESCO World Heritage Sites", + "Museums and Galleries", + "Amusement Parks", + "Water Parks", + "Zoos and Animal Sanctuaries", + "Botanical Gardens", + "Cultural Festivals", + "Music Festivals", + "Food Festivals", + "Art and Craft Fairs", + "Sports Events", + "Concerts and Live Performances", + "Destination Weddings", + "Honeymoon Packages", + "Group Tours", + "Family Vacations", + "Adventure Sports", + "Scenic Drives", + "Road Trips", + "River Cruises", + "Ocean Cruises", + "Luxury Train Journeys", + "Helicopter Tours", + "Hot Air Balloon Rides", + "Wildlife Photography Tours", + "Safari Expeditions", + "Desert Adventures", + "Cultural Immersion Programs", + "Language Learning Vacations", + "Volunteer Tourism", + "Educational Tours", + "Sustainable Travel Experiences", + "Solo Travel Adventures", + "Backpacking Trips", + "Student Travel Packages", + "Group Retreats", + "Family Reunions", + "Business Travel Services", + "Event Planning Services", + "Travel Insurance Services", + "Car Rental Services", + "Airport Shuttle Services", + "Guided City Tours", + "Adventure Travel Gear", + "Travel Photography Workshops", + "Travel Blogging and Vlogging", + "Travel Accessories and Luggage", + "Travel Booking Websites", + "Travel Agents and Consultants", + "Travel Rewards Programs", + "Travel Safety and Security", + "Travel Destination Guides" + ] + } + }, + { + "constraints": { + "business_category": "Restaurants" + }, + "replacements": { + "keyword": [ + "Fine Dining", + "Casual Dining", + "Fast Food", + "Cafés", + "Pizzerias", + "Seafood Restaurants", + "Steakhouse", + "Sushi Bars", + "Buffets", + "Vegetarian/Vegan Restaurants", + "Ethnic Cuisine (e.g., Italian, Mexican, Chinese)", + "Barbecue Restaurants", + "Food Trucks", + "Family-Friendly Restaurants", + "Diners", + "Bistros", + "Brunch Spots", + "Dessert Shops", + "Ice Cream Parlors", + "Bakeries", + "Food Court", + "Food Stalls", + "Food Delivery Services", + "Breweries with Food", + "Gastropubs", + "Tapas Bars", + "Farm-to-Table Restaurants", + "Organic Restaurants", + "Gluten-Free Restaurants", + "Street Food Vendors", + "Sandwich Shops", + "Salad Bars", + "Juice Bars", + "Oyster Bars", + "Ramen Shops", + "Noodle Houses", + "Fondue Restaurants", + "Burger Joints", + "Taco Stands", + "Mediterranean Restaurants", + "Indian Restaurants", + "Thai Restaurants", + "Korean Barbecue", + "Vegan Bakeries", + "Fondue Restaurants", + "French Restaurants", + "Cajun/Creole Restaurants", + "Brazilian Steakhouses", + "Teppanyaki Restaurants", + "Gastrobars", + "Hot Dog Stands", + "Waffle Houses", + "Bagel Shops", + "Taprooms with Food", + "Gourmet Food Trucks", + "Mongolian BBQ", + "Delis", + "Dim Sum Restaurants", + "Lebanese Restaurants", + "Ethiopian Restaurants", + "Malaysian Restaurants", + "Caribbean Restaurants", + "Irish Pubs with Food", 
+ "Vietnamese Pho Restaurants", + "Oyster Bars", + "Spanish Tapas Restaurants", + "Vegetarian Sushi Restaurants", + "Greek Tavernas", + "Brazilian Rodizio", + "Colombian Restaurants", + "Cuban Cafeterias", + "Indonesian Restaurants", + "Moroccan Restaurants", + "Peruvian Restaurants", + "Middle Eastern Meze Restaurants", + "Russian Restaurants", + "Belgian Waffle Houses", + "Fish and Chip Shops", + "Egyptian Restaurants", + "Nigerian Restaurants", + "Uzbek Cuisine Restaurants", + "Hawaiian Poke Restaurants", + "Texas Barbecue Joints", + "Southern Fried Chicken Restaurants", + "Filipino Restaurants", + "Turkish Kebab Houses", + "Israeli Falafel Stands", + "Scandinavian Smorgasbords", + "Argentine Parrillas", + "British Pubs with Food", + "Cambodian Restaurants", + "Czech Restaurants", + "Polish Pierogi Houses", + "Jamaican Jerk Chicken Stands", + "Mongolian Hot Pot Restaurants", + "Swiss Fondue Chalets", + "Guatemalan Restaurants", + "Nepalese Restaurants", + "Ecuadorian Restaurants", + "Bolivian Restaurants" + ] + } + }, + { + "constraints": { + "business_category": "Shopping" + }, + "replacements": { + "keyword": [ + "Clothing Stores", + "Shoe Stores", + "Accessories Stores", + "Jewelry Stores", + "Department Stores", + "Electronics Stores", + "Appliance Stores", + "Furniture Stores", + "Home Decor Stores", + "Sporting Goods Stores", + "Outdoor Gear Stores", + "Beauty and Cosmetics Stores", + "Health and Wellness Stores", + "Baby and Kids Stores", + "Toy Stores", + "Books and Stationery Stores", + "Music and Video Stores", + "Art and Craft Stores", + "Hobby and Collectibles Stores", + "Home Improvement Stores", + "Garden and Nursery Stores", + "Grocery Stores", + "Supermarkets", + "Organic Food Stores", + "Specialty Food Stores", + "Wine and Liquor Stores", + "Farmers Markets", + "Pet Stores", + "Office Supplies Stores", + "Computer and Electronics Stores", + "Cell Phone Stores", + "Gaming Stores", + "Vintage and Thrift Stores", + "Antique Stores", + "Flea Markets", + "Online Marketplaces", + "Auction Websites", + "Discount Stores", + "Outlet Malls", + "Shopping Centers", + "Luxury Brand Stores", + "Designer Boutiques", + "Custom Tailoring Stores", + "Maternity Stores", + "Plus-Size Clothing Stores", + "Men's Clothing Stores", + "Women's Clothing Stores", + "Children's Clothing Stores", + "Swimwear Stores", + "Lingerie Stores", + "Formal Wear Stores", + "Athletic Wear Stores", + "Sneaker Stores", + "Handbag Stores", + "Sunglasses Stores", + "Watch Stores", + "Fine Jewelry Stores", + "Costume Jewelry Stores", + "Vintage Jewelry Stores", + "Bridal Stores", + "Shoe Repair Shops", + "Athletic Shoe Stores", + "Luxury Shoe Stores", + "Sneaker Boutiques", + "Hat Stores", + "Scarf and Accessory Stores", + "Sock Stores", + "Beauty Supply Stores", + "Skincare Stores", + "Haircare Stores", + "Makeup Stores", + "Perfume Stores", + "Health Food Stores", + "Vitamin and Supplement Stores", + "Fitness Equipment Stores", + "Baby Clothing Stores", + "Baby Gear Stores", + "Toy Stores", + "Educational Toy Stores", + "Board Game Stores", + "Comic Book Stores", + "Music Stores", + "Movie Stores", + "Art Supply Stores", + "Craft Stores", + "Sewing Stores", + "Home Improvement Stores", + "Tools Stores", + "Paint Stores", + "Lighting Stores", + "Garden Supply Stores", + "Plant Stores", + "Organic Food Stores", + "Local Farmers Markets", + "Pet Food Stores", + "Pet Supply Stores", + "Office Supply Stores", + "Paper and Stationery Stores", + "Technology Stores", + "Online Retailers" + ] + } + } + ] +} \ No newline 
at end of file diff --git a/example/text/yelp_huggingface/variation_api_prompt.json b/example/text/yelp_huggingface/variation_api_prompt.json new file mode 100644 index 0000000..9de6575 --- /dev/null +++ b/example/text/yelp_huggingface/variation_api_prompt.json @@ -0,0 +1,45 @@ +{ + "message_template": [ + { + "role": "user", + "content": "Based on Business Category: {business_category}\tReview Stars: {review_stars}, please rephrase the following sentences {tone}:\n{sample}" + } + ], + "replacement_rules": [ + { + "constraints": {}, + "replacements": { + "tone": [ + "in a casual way", + "in a creative style", + "in an informal way", + "casually", + "in a detailed way", + "in a professional way", + "with more details", + "with a professional tone", + "in a casual style", + "in a professional style", + "in a short way", + "in a concise manner", + "concisely", + "briefly", + "orally", + "with imagination", + "with a tone of earnestness", + "in a grammarly-incorrect way", + "with grammatical errors", + "in a non-standard grammar fashion", + "in an oral way", + "in a spoken manner", + "articulately", + "by word of mouth", + "in a storytelling tone", + "in a formal manner", + "with an informal tone", + "in a laid-back manner" + ] + } + } + ] +} \ No newline at end of file diff --git a/example/text/yelp_openai/main.py b/example/text/yelp_openai/main.py new file mode 100644 index 0000000..5fae6b1 --- /dev/null +++ b/example/text/yelp_openai/main.py @@ -0,0 +1,106 @@ +""" +This example follows the experimental settings of the GPT-3.5 Yelp experiments in the ICML 2024 Spotlight paper, +"Differentially Private Synthetic Data via Foundation Model APIs 2: Text" (https://arxiv.org/abs/2403.01749), except +that the model is changed from GPT-3.5 to GPT-4o-mini-2024-07-18 as the original GPT-3.5 model version used in the +paper is no longer available. + +To run the code, the following environment variables are required: +* OPENAI_API_KEY: OpenAI API key. You can get it from https://platform.openai.com/account/api-keys. Multiple keys can + be separated by commas, and a key will be selected randomly for each request. + +We can also switch from OpenAI API to Azure OpenAI API by using :py:class:`pe.llm.azure_openai.AzureOpenAILLM` instead +of :py:class:`pe.llm.openai.OpenAILLM`. In that case, the following environment variables are required: +* AZURE_OPENAI_API_KEY: Azure OpenAI API key. You can get it from https://portal.azure.com/. Multiple keys can be + separated by commas, and a key will be selected randomly for each request. The key can also be "AZ_CLI", in which + case the Azure CLI will be used to authenticate the requests, and the environment variable AZURE_OPENAI_API_SCOPE + needs to be set. See Azure OpenAI authentication documentation for more information: + https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#microsoft-entra-id-authentication +* AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint. You can get it from https://portal.azure.com/. +* AZURE_OPENAI_API_VERSION: Azure OpenAI API version. You can get it from https://portal.azure.com/. + +These environment variables can be set in a .env file in the same directory as this script. For example: +``` +OPENAI_API_KEY=your_openai_api_key +``` +See https://github.com/theskumar/python-dotenv for more information about the .env file. + +For detailed information about parameters and APIs, please consult the documentation of the Private Evolution library: +https://microsoft.github.io/DPSDA/. 
+""" + +from dotenv import load_dotenv + +from pe.data.text import Yelp +from pe.logging import setup_logging +from pe.runner import PE +from pe.population import PEPopulation +from pe.api.text import LLMAugPE +from pe.llm import OpenAILLM +from pe.embedding.text import SentenceTransformer +from pe.histogram import NearestNeighbors +from pe.callback import SaveCheckpoints +from pe.callback import ComputeFID +from pe.callback import SaveTextToCSV +from pe.logger import CSVPrint +from pe.logger import LogPrint + +import pandas as pd +import os +import numpy as np + +pd.options.mode.copy_on_write = True + + +if __name__ == "__main__": + exp_folder = "results/text/yelp_openai_api" + current_folder = os.path.dirname(os.path.abspath(__file__)) + + load_dotenv() + + setup_logging(log_file=os.path.join(exp_folder, "log.txt")) + + data = Yelp(root_dir="/tmp/data/yelp") + llm = OpenAILLM(max_completion_tokens=128, model="gpt-4o-mini-2024-07-18", temperature=1.4, num_threads=4) + api = LLMAugPE( + llm=llm, + random_api_prompt_file=os.path.join(current_folder, "random_api_prompt.json"), + variation_api_prompt_file=os.path.join(current_folder, "variation_api_prompt.json"), + min_word_count=25, + word_count_std=20, + token_to_word_ratio=1.2, + max_completion_tokens_limit=1200, + blank_probabilities=0.5, + ) + embedding = SentenceTransformer(model="stsb-roberta-base-v2") + histogram = NearestNeighbors( + embedding=embedding, + mode="L2", + lookahead_degree=0, + ) + population = PEPopulation( + api=api, initial_variation_api_fold=3, next_variation_api_fold=3, keep_selected=True, selection_mode="rank" + ) + + save_checkpoints = SaveCheckpoints(os.path.join(exp_folder, "checkpoint")) + compute_fid = ComputeFID(priv_data=data, embedding=embedding) + save_text_to_csv = SaveTextToCSV(output_folder=os.path.join(exp_folder, "synthetic_text")) + + csv_print = CSVPrint(output_folder=exp_folder) + log_print = LogPrint() + + num_private_samples = len(data.data_frame) + delta = 1.0 / num_private_samples / np.log(num_private_samples) + + pe_runner = PE( + priv_data=data, + population=population, + histogram=histogram, + callbacks=[save_checkpoints, save_text_to_csv, compute_fid], + loggers=[csv_print, log_print], + ) + pe_runner.run( + num_samples_schedule=[5000] * 21, + delta=delta, + epsilon=1.0, + checkpoint_path=os.path.join(exp_folder, "checkpoint"), + ) diff --git a/example/text/yelp_openai/random_api_prompt.json b/example/text/yelp_openai/random_api_prompt.json new file mode 100644 index 0000000..1a542cf --- /dev/null +++ b/example/text/yelp_openai/random_api_prompt.json @@ -0,0 +1,1104 @@ +{ + "message_template": [ + { + "role": "system", + "content": "You are required to write an example of review based on the provided Business Category and Review Stars that fall within the range of 1.0-5.0." 
+ }, + { + "role": "user", + "content": "Business Category: {business_category}\tReview Stars: {review_stars} with keyword {keyword}" + } + ], + "replacement_rules": [ + { + "constraints": { + "business_category": "Arts & Entertainment" + }, + "replacements": { + "keyword": [ + "Art Galleries", + "Museums", + "Live Music Venues", + "Theaters", + "Dance Studios", + "Comedy Clubs", + "Film Festivals", + "Performing Arts Centers", + "Concert Halls", + "Jazz Clubs", + "Opera Houses", + "Symphony Orchestras", + "Ballet Companies", + "Art Exhibitions", + "Street Performances", + "Improv Shows", + "Stand-Up Comedy", + "Music Festivals", + "Film Screening Events", + "Art Workshops", + "Art Classes", + "Art Installations", + "Art Fairs", + "Sculpture Gardens", + "Public Art Displays", + "Art Auctions", + "Art Museums", + "Contemporary Art Spaces", + "Ceramic Studios", + "Pottery Classes", + "Photography Exhibitions", + "Street Art Tours", + "Graffiti Art", + "Poetry Readings", + "Literary Festivals", + "Bookstores", + "Storytelling Events", + "Cabaret Shows", + "Magic Shows", + "Circus Performances", + "Puppet Shows", + "Fashion Shows", + "Body Painting", + "Burlesque Shows", + "Drag Performances", + "Spoken Word Performances", + "Opera Performances", + "Outdoor Concerts", + "Tribute Bands", + "Music Open Mic Nights", + "Indie Music Venues", + "Jazz Festivals", + "Theatre Festivals", + "Dance Performances", + "Dance Workshops", + "Dance Competitions", + "Film Premiers", + "Film Screenwriting Workshops", + "Film Production Studios", + "Animation Studios", + "Film Awards Ceremonies", + "Film Retrospectives", + "Film Noir Screenings", + "Film Documentaries", + "Independent Film Showcases", + "Film Director Q&A Sessions", + "Art History Lectures", + "Art Tours", + "Art Therapy Workshops", + "Art Supply Stores", + "Art Conservation Services", + "Public Murals", + "Live Street Art Performances", + "Art Film Screenings", + "Art Restoration Services", + "Art Book Signings", + "Art Magazine Launches", + "Artisan Markets", + "Artisan Workshops", + "DIY Craft Events", + "Virtual Reality Experiences", + "Gaming Conventions", + "Esports Tournaments", + "Cosplay Events", + "Comic Book Conventions", + "Anime Festivals", + "Pop Culture Expos", + "Board Game Cafés", + "Trivia Nights", + "Karaoke Bars", + "Outdoor Theater Performances", + "Artisan Food and Beverage Tastings", + "Historic Site Tours", + "Music Education Programs", + "Film Soundtrack Concerts", + "Art Performance Installations", + "Artisanal Food Markets", + "Circus Training Workshops", + "Live Painting Demonstrations", + "Art Film Festivals" + ] + } + }, + { + "constraints": { + "business_category": "Bars" + }, + "replacements": { + "keyword": [ + "Sports Bars", + "Dive Bars", + "Cocktail Bars", + "Brewpubs", + "Wine Bars", + "Karaoke Bars", + "Jazz Bars", + "Tiki Bars", + "Rooftop Bars", + "Irish Pubs", + "Whiskey Bars", + "Beer Gardens", + "Speakeasies", + "Neighborhood Bars", + "Gay Bars", + "Salsa Bars", + "Cigar Bars", + "Piano Bars", + "Country Bars", + "College Bars", + "Hotel Bars", + "Live Music Bars", + "Beach Bars", + "Craft Beer Bars", + "Comedy Clubs with Bars", + "Tequila Bars", + "Rum Bars", + "Gin Bars", + "Martini Bars", + "Bourbon Bars", + "Scotch Bars", + "Blues Bars", + "Reggae Bars", + "Wine Tasting Bars", + "Distillery Bars", + "Outdoor Bars", + "Latin Bars", + "Upscale Bars", + "Lounge Bars", + "Artisanal Cocktail Bars", + "Whiskey Tasting Bars", + "Cider Bars", + "Hipster Bars", + "Underground Bars", + "Cabaret Bars", + 
"Burlesque Bars", + "Board Game Bars", + "Arcade Bars", + "Craft Cocktail Bars", + "Sake Bars", + "Biker Bars", + "Tapas Bars", + "Microbrewery Taprooms", + "Speakeasy-Style Bars", + "Absinthe Bars", + "Vodka Bars", + "Beachfront Bars", + "80s Bars", + "90s Bars", + "Swanky Bars", + "Rum Tasting Bars", + "Irish Whiskey Bars", + "Sours Bars", + "Whiskey and Cigar Bars", + "Themed Bars", + "Mezcal Bars", + "Hawaiian Tiki Bars", + "German Beer Halls", + "Rooftop Sky Bars", + "Rustic Bars", + "Wine Cellar Bars", + "Gin and Tonic Bars", + "Underground Speakeasies", + "Jazz Speakeasies", + "Secret Bars", + "Piano Karaoke Bars", + "Reggaeton Bars", + "Bachelorette Party Bars", + "Outdoor Rooftop Bars", + "Blues and BBQ Bars", + "Rum Distillery Bars", + "Cocktail Mixology Bars", + "Margarita Bars", + "Classic Cocktail Bars", + "Nightclub Bars", + "Whiskey and BBQ Bars", + "Cabana Bars", + "Rooftop Pool Bars", + "Drag Bars", + "Wine and Cheese Bars", + "Prohibition-Style Bars", + "Tropical Bars", + "Latin Dance Bars", + "Rum Tiki Bars", + "Rooftop Lounge Bars", + "Beer Flight Bars", + "Barrel-Aged Beer Bars", + "Tropical Cocktail Bars", + "Jazz and Blues Bars", + "Outdoor Beach Bars" + ] + } + }, + { + "constraints": { + "business_category": "Beauty & Spas" + }, + "replacements": { + "keyword": [ + "Hair Salons", + "Nail Salons", + "Day Spas", + "Massage Therapy", + "Facial Services", + "Waxing Studios", + "Eyebrow Threading", + "Makeup Artists", + "Barber Shops", + "Hair Removal Services", + "Tanning Salons", + "Body Treatments", + "Manicure and Pedicure Services", + "Spa Packages", + "Skin Care Clinics", + "Medical Spas", + "Acupuncture Services", + "Ayurvedic Treatments", + "Aromatherapy", + "Reflexology Services", + "Reiki Healing", + "Holistic Wellness Centers", + "Hair Color Services", + "Hair Extensions", + "Bridal Hair and Makeup Services", + "Lash Extensions", + "Microblading Services", + "Tattoo Studios", + "Permanent Makeup", + "Blowout Services", + "Scalp Treatments", + "Brazilian Waxing", + "Eyelash Lift and Tint", + "Nail Art Studios", + "Gel Nail Services", + "Hair Braiding", + "Spa Facials", + "Hot Stone Massages", + "Deep Tissue Massages", + "Swedish Massages", + "Sports Massages", + "Thai Massages", + "Couples Massages", + "Body Scrubs", + "Body Wraps", + "Sauna Services", + "Floatation Therapy", + "Cupping Therapy", + "Oxygen Bar Services", + "Spa Manicures and Pedicures", + "Shellac Nail Services", + "Gel Polish Removal", + "Foot Massage Services", + "Back Massage Services", + "Anti-Aging Treatments", + "Chemical Peels", + "Microdermabrasion", + "Laser Hair Removal", + "Botox Injections", + "Dermal Fillers", + "Lip Enhancement", + "Facial Rejuvenation", + "Coolsculpting Services", + "Fat Reduction Treatments", + "Body Contouring", + "Cryotherapy Services", + "Spa Parties", + "Mobile Beauty Services", + "Wellness Retreats", + "Meditation Classes", + "Yoga Studios", + "Pilates Studios", + "Fitness Centers with Spa Services", + "Infrared Sauna Therapy", + "Body Piercing Studios", + "Non-Surgical Facelifts", + "Vampire Facials", + "Microneedling Services", + "Teeth Whitening Services", + "Hair Straightening Services", + "Balayage Services", + "Highlights and Lowlights", + "Keratin Treatments", + "Scalp Micropigmentation", + "Henna Tattoos", + "Laser Skin Resurfacing", + "Body Piercing Jewelry and Accessories", + "Spa Membership Programs", + "Esthetician Services", + "Reflexology Foot Spas", + "Hand and Arm Massages", + "Deep Cleansing Facials", + "Couples Spa Packages", + "Lymphatic 
Drainage Massages", + "Eyebrow and Eyelash Tinting", + "Beard Grooming Services", + "Spa Consultations", + "Wellness Coaching", + "Hair Loss Treatments", + "Herbal Wraps" + ] + } + }, + { + "constraints": { + "business_category": "Event Planning & Services" + }, + "replacements": { + "keyword": [ + "Wedding Planning", + "Corporate Event Planning", + "Party Planning", + "Event Decorations", + "Event Rentals", + "Event Photography", + "Event Videography", + "Event Lighting Services", + "Event DJ Services", + "Event Catering Services", + "Event Staffing", + "Event Security Services", + "Event Ticketing Services", + "Event Transportation Services", + "Event Marketing and Promotion", + "Event Audiovisual Services", + "Event Technology Solutions", + "Event Production Services", + "Event Venue Selection", + "Event Registration Services", + "Event Signage and Branding", + "Event Graphic Design", + "Event Floral Design", + "Event Entertainment", + "Event Emcees and Hosts", + "Event Planning Consultation", + "Event Logistics Management", + "Event Budgeting and Financial Planning", + "Event Stage Design and Setup", + "Event Theme Development", + "Event Auction Services", + "Event Sponsorship Management", + "Event Public Relations", + "Event Social Media Management", + "Event Website Design and Development", + "Event Crowd Management", + "Event Exhibitor Services", + "Event Printing and Collateral Services", + "Event Risk Assessment and Management", + "Event Equipment Rentals", + "Event Health and Safety Services", + "Event Waste Management", + "Event First Aid Services", + "Event Interpretation and Translation Services", + "Event Virtual and Hybrid Solutions", + "Event Drone Photography and Videography", + "Event Projection Mapping", + "Event Fireworks and Pyrotechnics", + "Event Tent and Canopy Rentals", + "Event Flooring and Staging", + "Event Valet Parking Services", + "Event Marketing Collateral Design", + "Event Theme Party Planning", + "Event Destination Management", + "Event Team Building Activities", + "Event Press and Media Coverage", + "Event Celebrity Booking", + "Event Fashion Show Production", + "Event Awards and Recognition Programs", + "Event Live Streaming Services", + "Event Sound System Rentals", + "Event Trade Show Booth Design", + "Event Product Launch Planning", + "Event Fashion Styling and Consulting", + "Event Charity Auctions", + "Event Fashion Runway Design", + "Event Event Swag and Merchandise", + "Event Costume and Prop Rentals", + "Event Concert Production", + "Event Venue Coordination", + "Event Invitation and Stationery Design", + "Event Wine and Beverage Services", + "Event Social Event Planning", + "Event Silent Auctions", + "Event Celebrity Meet and Greet", + "Event Food and Beverage Pairing", + "Event Bar and Bartending Services", + "Event Destination Weddings", + "Event Fundraising and Development", + "Event Gala Dinners", + "Event Event Branding and Identity", + "Event Conference Planning", + "Event Team Registration and Management", + "Event Incentive Travel Planning", + "Event Event Website and App Development", + "Event Theme Park and Attraction Planning", + "Event Product Demonstrations", + "Event Inflatable Rentals", + "Event Themed Entertainment", + "Event Drone Shows", + "Event Costume Design and Creation", + "Event Run/Walk/Ride Planning", + "Event Wine Tasting and Pairing", + "Event Event App Development", + "Event Pop-Up Shop Planning", + "Event Street Marketing and Promotions", + "Event Audio and Visual Equipment Sales", + "Event VIP Experiences", + 
"Event Mobile App Development", + "Event Gaming and Esports Experiences" + ] + } + }, + { + "constraints": { + "business_category": "Grocery" + }, + "replacements": { + "keyword": [ + "Fresh Fruits", + "Fresh Vegetables", + "Organic Produce", + "Herbs and Spices", + "Dairy Products", + "Eggs", + "Butter and Margarine", + "Milk", + "Yogurt", + "Cheese", + "Deli Meats", + "Fresh Bakery Products", + "Breads", + "Rolls and Bagels", + "Cakes and Pastries", + "Gluten-Free Products", + "Frozen Foods", + "Frozen Vegetables", + "Frozen Fruits", + "Frozen Meals", + "Ice Cream and Frozen Desserts", + "Canned Goods", + "Canned Fruits", + "Canned Vegetables", + "Canned Soups", + "Canned Beans", + "Canned Fish and Seafood", + "Condiments and Sauces", + "Ketchup and Mustard", + "Mayonnaise and Salad Dressings", + "Cooking Oils", + "Vinegars", + "Pasta and Noodles", + "Rice", + "Grains and Lentils", + "Breakfast Cereals", + "Oatmeal and Porridge", + "Pancake and Waffle Mixes", + "Snack Foods", + "Chips and Pretzels", + "Crackers", + "Popcorn", + "Nuts and Seeds", + "Dried Fruits", + "Chocolate and Candy", + "Beverages", + "Soft Drinks", + "Juices", + "Coffee", + "Tea", + "Bottled Water", + "Energy Drinks", + "Alcoholic Beverages", + "Beer", + "Wine", + "Spirits and Liquors", + "Baby Products", + "Baby Food", + "Diapers", + "Baby Formula", + "Personal Care Products", + "Shampoo and Conditioner", + "Soap and Body Wash", + "Toothpaste and Toothbrushes", + "Feminine Hygiene Products", + "Deodorant", + "Laundry Supplies", + "Laundry Detergent", + "Fabric Softener", + "Stain Removers", + "Paper Products", + "Toilet Paper", + "Paper Towels", + "Napkins", + "Cleaning Supplies", + "All-Purpose Cleaners", + "Dishwashing Liquid", + "Surface Disinfectants", + "Trash Bags", + "Pet Food", + "Cat Food", + "Dog Food", + "Pet Treats", + "Pet Supplies", + "Health and Wellness Products", + "Vitamins and Supplements", + "Over-the-Counter Medications", + "First Aid Supplies", + "Oral Care Products", + "Household Essentials", + "Light Bulbs", + "Batteries", + "Gardening Supplies", + "Cooking Ingredients", + "Baking Supplies", + "Sugar and Sweeteners", + "Flour", + "Yeast", + "Spreads and Jams", + "Ethnic and International Foods" + ] + } + }, + { + "constraints": { + "business_category": "Health & Medical" + }, + "replacements": { + "keyword": [ + "Family Medicine", + "Internal Medicine", + "Pediatrics", + "Obstetrics and Gynecology", + "Cardiology", + "Dermatology", + "Endocrinology", + "Gastroenterology", + "Neurology", + "Oncology", + "Ophthalmology", + "Orthopedics", + "Otolaryngology (ENT)", + "Psychiatry", + "Pulmonology", + "Rheumatology", + "Urology", + "Allergy and Immunology", + "Anesthesiology", + "Emergency Medicine", + "Radiology", + "Pathology", + "Physical Therapy", + "Occupational Therapy", + "Speech Therapy", + "Chiropractic Care", + "Acupuncture", + "Homeopathy", + "Naturopathy", + "Nutrition and Dietetics", + "Mental Health Counseling", + "Rehabilitation Services", + "Sleep Medicine", + "Pain Management", + "Geriatric Medicine", + "Sports Medicine", + "Plastic Surgery", + "Ophthalmic Surgery", + "Cardiothoracic Surgery", + "Oral and Maxillofacial Surgery", + "Podiatry", + "Audiology", + "Optometry", + "Dentistry", + "Orthodontics", + "Periodontics", + "Prosthodontics", + "Oral Surgery", + "Medical Imaging Services", + "Laboratory Services", + "Pharmaceutical Services", + "Rehabilitation Centers", + "Assisted Living Facilities", + "Nursing Homes", + "Home Health Care Services", + "Hospice Care", + "Mental 
Health Hospitals", + "Cancer Centers", + "Pain Clinics", + "Rehabilitation Hospitals", + "Diagnostic Centers", + "Ambulatory Surgery Centers", + "Urgent Care Centers", + "Walk-in Clinics", + "Wellness Centers", + "Weight Loss Centers", + "Rehabilitation Centers", + "Physical Fitness Centers", + "Massage Therapy", + "Yoga and Pilates Studios", + "Health Coaching", + "Health Education Programs", + "Chronic Disease Management", + "Diabetes Care and Education", + "Asthma and Allergy Management", + "HIV/AIDS Care and Support", + "Women's Health Services", + "Men's Health Services", + "Senior Health Services", + "Pediatric Care Services", + "Sexual and Reproductive Health Services", + "Preventive Care Services", + "Vaccination Clinics", + "Pain Clinics", + "Smoking Cessation Programs", + "Weight Management Programs", + "Rehabilitation Services for Injuries", + "Genetic Counseling", + "Telemedicine Services", + "Integrative Medicine", + "Cancer Support Services", + "Stroke Rehabilitation Services", + "Dementia Care Services", + "Chronic Pain Management", + "Diabetes Education and Management", + "Physical and Occupational Rehabilitation", + "Mental Health Support Groups", + "Orthopedic Rehabilitation Services", + "Addiction Treatment Centers", + "Holistic Healing Centers" + ] + } + }, + { + "constraints": { + "business_category": "Home & Garden" + }, + "replacements": { + "keyword": [ + "Furniture", + "Home Decor", + "Lighting Fixtures", + "Kitchen Appliances", + "Bathroom Fixtures", + "Bedroom Furniture", + "Living Room Furniture", + "Dining Room Furniture", + "Home Office Furniture", + "Outdoor Furniture", + "Home Organization", + "Storage Solutions", + "Wall Art and Paintings", + "Rugs and Carpets", + "Curtains and Window Treatments", + "Throw Pillows and Cushions", + "Tableware and Dinnerware", + "Cookware and Bakeware", + "Small Kitchen Appliances", + "Home Cleaning Products", + "Laundry Supplies", + "Home Fragrances", + "Indoor Plants and Pots", + "Garden Tools", + "Outdoor Grills and Cooking Equipment", + "Patio Furniture", + "Outdoor Lighting", + "Lawn and Garden Care", + "Gardening Accessories", + "Fertilizers and Soil Amendments", + "Outdoor Decorations", + "Planters and Raised Beds", + "Swimming Pool Supplies", + "Hot Tubs and Spas", + "Outdoor Power Equipment", + "Pest Control Products", + "Home Security Systems", + "Smart Home Devices", + "Home Improvement Tools", + "Power Tools", + "Hand Tools", + "Plumbing Supplies", + "Electrical Supplies", + "Paint and Painting Supplies", + "Flooring Materials", + "Tiles and Mosaics", + "Wallpaper and Wall Coverings", + "Home Renovation Services", + "Kitchen Renovation Services", + "Bathroom Renovation Services", + "Flooring Installation Services", + "Interior Design Services", + "Landscape Design Services", + "Home Theater Systems", + "Home Office Equipment", + "Home Gym Equipment", + "Outdoor Playsets and Swing Sets", + "Fireplaces and Wood Stoves", + "Home Energy Efficiency Products", + "Solar Panels and Systems", + "Roofing Materials", + "Doors and Windows", + "Garage Doors and Openers", + "Home Insulation", + "HVAC Systems", + "Ceiling Fans and Ventilation", + "Home Water Filtration Systems", + "Home Security Cameras", + "Lawn Mowers and Tractors", + "Gardening Books and Magazines", + "Bird Feeders and Houses", + "Composting Supplies", + "Garden Irrigation Systems", + "Outdoor Storage Sheds", + "Deck and Patio Materials", + "Outdoor Structures (Gazebos, Pergolas)", + "Lawn and Garden Decor", + "BBQ and Grill Accessories", + "Home Maintenance 
and Repair Services", + "Carpet Cleaning Services", + "Furniture Restoration Services", + "Home Painting Services", + "Plumbing Services", + "Electrical Services", + "Roofing Services", + "Home Cleaning Services", + "Home Staging Services", + "Pest Control Services", + "Landscaping Services", + "Tree Care and Arborist Services", + "Home Inspection Services", + "Interior Decorating Services", + "Garage Organization Systems", + "Closet and Storage Solutions", + "Home Audio Systems", + "Home Automation Systems", + "Greenhouse Supplies", + "Aquatic Plants and Pond Supplies", + "Outdoor Sound Systems", + "Holiday Decorations" + ] + } + }, + { + "constraints": { + "business_category": "Hotels & Travel" + }, + "replacements": { + "keyword": [ + "Luxury Hotels", + "Boutique Hotels", + "Budget Hotels", + "Resort Hotels", + "Bed and Breakfasts", + "Hostels", + "Vacation Rentals", + "Beachfront Hotels", + "Mountain Resorts", + "Spa Resorts", + "Casino Hotels", + "Pet-Friendly Hotels", + "Family-Friendly Hotels", + "Eco-Friendly Hotels", + "All-Inclusive Resorts", + "Business Hotels", + "Airport Hotels", + "Historic Hotels", + "Honeymoon Resorts", + "Ski Resorts", + "Golf Resorts", + "Safari Lodges", + "Wellness Retreats", + "Glamping Sites", + "RV Parks and Campgrounds", + "Cruise Ships", + "Tourist Attractions", + "Sightseeing Tours", + "Adventure Travel", + "Cultural Tours", + "Food and Wine Tours", + "Outdoor Activities", + "Beach Vacations", + "City Breaks", + "Mountain Escapes", + "Hiking and Trekking Trips", + "Wildlife Safaris", + "Scuba Diving Excursions", + "Snorkeling Tours", + "Whale Watching Trips", + "Surfing Camps", + "Yoga Retreats", + "Wellness and Spa Getaways", + "Wine Country Tours", + "Culinary Vacations", + "Historical Landmarks", + "Theme Parks", + "National Parks", + "UNESCO World Heritage Sites", + "Museums and Galleries", + "Amusement Parks", + "Water Parks", + "Zoos and Animal Sanctuaries", + "Botanical Gardens", + "Cultural Festivals", + "Music Festivals", + "Food Festivals", + "Art and Craft Fairs", + "Sports Events", + "Concerts and Live Performances", + "Destination Weddings", + "Honeymoon Packages", + "Group Tours", + "Family Vacations", + "Adventure Sports", + "Scenic Drives", + "Road Trips", + "River Cruises", + "Ocean Cruises", + "Luxury Train Journeys", + "Helicopter Tours", + "Hot Air Balloon Rides", + "Wildlife Photography Tours", + "Safari Expeditions", + "Desert Adventures", + "Cultural Immersion Programs", + "Language Learning Vacations", + "Volunteer Tourism", + "Educational Tours", + "Sustainable Travel Experiences", + "Solo Travel Adventures", + "Backpacking Trips", + "Student Travel Packages", + "Group Retreats", + "Family Reunions", + "Business Travel Services", + "Event Planning Services", + "Travel Insurance Services", + "Car Rental Services", + "Airport Shuttle Services", + "Guided City Tours", + "Adventure Travel Gear", + "Travel Photography Workshops", + "Travel Blogging and Vlogging", + "Travel Accessories and Luggage", + "Travel Booking Websites", + "Travel Agents and Consultants", + "Travel Rewards Programs", + "Travel Safety and Security", + "Travel Destination Guides" + ] + } + }, + { + "constraints": { + "business_category": "Restaurants" + }, + "replacements": { + "keyword": [ + "Fine Dining", + "Casual Dining", + "Fast Food", + "Cafés", + "Pizzerias", + "Seafood Restaurants", + "Steakhouse", + "Sushi Bars", + "Buffets", + "Vegetarian/Vegan Restaurants", + "Ethnic Cuisine (e.g., Italian, Mexican, Chinese)", + "Barbecue Restaurants", + "Food 
Trucks", + "Family-Friendly Restaurants", + "Diners", + "Bistros", + "Brunch Spots", + "Dessert Shops", + "Ice Cream Parlors", + "Bakeries", + "Food Court", + "Food Stalls", + "Food Delivery Services", + "Breweries with Food", + "Gastropubs", + "Tapas Bars", + "Farm-to-Table Restaurants", + "Organic Restaurants", + "Gluten-Free Restaurants", + "Street Food Vendors", + "Sandwich Shops", + "Salad Bars", + "Juice Bars", + "Oyster Bars", + "Ramen Shops", + "Noodle Houses", + "Fondue Restaurants", + "Burger Joints", + "Taco Stands", + "Mediterranean Restaurants", + "Indian Restaurants", + "Thai Restaurants", + "Korean Barbecue", + "Vegan Bakeries", + "Fondue Restaurants", + "French Restaurants", + "Cajun/Creole Restaurants", + "Brazilian Steakhouses", + "Teppanyaki Restaurants", + "Gastrobars", + "Hot Dog Stands", + "Waffle Houses", + "Bagel Shops", + "Taprooms with Food", + "Gourmet Food Trucks", + "Mongolian BBQ", + "Delis", + "Dim Sum Restaurants", + "Lebanese Restaurants", + "Ethiopian Restaurants", + "Malaysian Restaurants", + "Caribbean Restaurants", + "Irish Pubs with Food", + "Vietnamese Pho Restaurants", + "Oyster Bars", + "Spanish Tapas Restaurants", + "Vegetarian Sushi Restaurants", + "Greek Tavernas", + "Brazilian Rodizio", + "Colombian Restaurants", + "Cuban Cafeterias", + "Indonesian Restaurants", + "Moroccan Restaurants", + "Peruvian Restaurants", + "Middle Eastern Meze Restaurants", + "Russian Restaurants", + "Belgian Waffle Houses", + "Fish and Chip Shops", + "Egyptian Restaurants", + "Nigerian Restaurants", + "Uzbek Cuisine Restaurants", + "Hawaiian Poke Restaurants", + "Texas Barbecue Joints", + "Southern Fried Chicken Restaurants", + "Filipino Restaurants", + "Turkish Kebab Houses", + "Israeli Falafel Stands", + "Scandinavian Smorgasbords", + "Argentine Parrillas", + "British Pubs with Food", + "Cambodian Restaurants", + "Czech Restaurants", + "Polish Pierogi Houses", + "Jamaican Jerk Chicken Stands", + "Mongolian Hot Pot Restaurants", + "Swiss Fondue Chalets", + "Guatemalan Restaurants", + "Nepalese Restaurants", + "Ecuadorian Restaurants", + "Bolivian Restaurants" + ] + } + }, + { + "constraints": { + "business_category": "Shopping" + }, + "replacements": { + "keyword": [ + "Clothing Stores", + "Shoe Stores", + "Accessories Stores", + "Jewelry Stores", + "Department Stores", + "Electronics Stores", + "Appliance Stores", + "Furniture Stores", + "Home Decor Stores", + "Sporting Goods Stores", + "Outdoor Gear Stores", + "Beauty and Cosmetics Stores", + "Health and Wellness Stores", + "Baby and Kids Stores", + "Toy Stores", + "Books and Stationery Stores", + "Music and Video Stores", + "Art and Craft Stores", + "Hobby and Collectibles Stores", + "Home Improvement Stores", + "Garden and Nursery Stores", + "Grocery Stores", + "Supermarkets", + "Organic Food Stores", + "Specialty Food Stores", + "Wine and Liquor Stores", + "Farmers Markets", + "Pet Stores", + "Office Supplies Stores", + "Computer and Electronics Stores", + "Cell Phone Stores", + "Gaming Stores", + "Vintage and Thrift Stores", + "Antique Stores", + "Flea Markets", + "Online Marketplaces", + "Auction Websites", + "Discount Stores", + "Outlet Malls", + "Shopping Centers", + "Luxury Brand Stores", + "Designer Boutiques", + "Custom Tailoring Stores", + "Maternity Stores", + "Plus-Size Clothing Stores", + "Men's Clothing Stores", + "Women's Clothing Stores", + "Children's Clothing Stores", + "Swimwear Stores", + "Lingerie Stores", + "Formal Wear Stores", + "Athletic Wear Stores", + "Sneaker Stores", + "Handbag Stores", 
+ "Sunglasses Stores", + "Watch Stores", + "Fine Jewelry Stores", + "Costume Jewelry Stores", + "Vintage Jewelry Stores", + "Bridal Stores", + "Shoe Repair Shops", + "Athletic Shoe Stores", + "Luxury Shoe Stores", + "Sneaker Boutiques", + "Hat Stores", + "Scarf and Accessory Stores", + "Sock Stores", + "Beauty Supply Stores", + "Skincare Stores", + "Haircare Stores", + "Makeup Stores", + "Perfume Stores", + "Health Food Stores", + "Vitamin and Supplement Stores", + "Fitness Equipment Stores", + "Baby Clothing Stores", + "Baby Gear Stores", + "Toy Stores", + "Educational Toy Stores", + "Board Game Stores", + "Comic Book Stores", + "Music Stores", + "Movie Stores", + "Art Supply Stores", + "Craft Stores", + "Sewing Stores", + "Home Improvement Stores", + "Tools Stores", + "Paint Stores", + "Lighting Stores", + "Garden Supply Stores", + "Plant Stores", + "Organic Food Stores", + "Local Farmers Markets", + "Pet Food Stores", + "Pet Supply Stores", + "Office Supply Stores", + "Paper and Stationery Stores", + "Technology Stores", + "Online Retailers" + ] + } + } + ] +} \ No newline at end of file diff --git a/example/text/yelp_openai/variation_api_prompt.json b/example/text/yelp_openai/variation_api_prompt.json new file mode 100644 index 0000000..b2edff5 --- /dev/null +++ b/example/text/yelp_openai/variation_api_prompt.json @@ -0,0 +1,12 @@ +{ + "message_template": [ + { + "role": "system", + "content": "You are a helpful, pattern-following assistant." + }, + { + "role": "user", + "content": "Based on the Business Category and Review Stars, you are required to fill in the blanks in the Input sentences with grammatical errors. If there are no blanks, you are required to output the original Input sentences.\nBusiness Category: Restaurants\tReview Stars: 2.0\nInput: _ that great , terrible _ rolls and fish _ smelling _ _.\nFill-in-Blanks and your answer MUST be exactly 10 words: Not that great, terrible egg rolls and fishy smelling shrimp.\nBusiness Category: Beauty & Spas\tReview Stars: 5.0\nInput: Very clean! Staff are super friendly!!\nFill-in-Blanks and your answer MUST be exactly 6 words: Very clean! Staff are super friendly!!\nBusiness Category: Shopping\tReview Stars: 3.0\nInput: I _ in _ and stopped in for a _. I was _ surprised. Good _, nice price.\nFill-in-Blanks and your answer MUST be exactly 19 words: I was in a rush and stopped in for a mani-pedi. I was pleasantly surprised. 
Good service, nice price.\nBusiness Category: {business_category}\tReview Stars: {review_stars} \nInput: {masked_sample} \nFill-in-Blanks and your answer MUST be exactly {word_count} words:" + } + ] +} \ No newline at end of file diff --git a/pe/api/__init__.py b/pe/api/__init__.py index ae40d11..5def7f3 100644 --- a/pe/api/__init__.py +++ b/pe/api/__init__.py @@ -1 +1,5 @@ from .api import API +from .image import ImprovedDiffusion, ImprovedDiffusion270M, StableDiffusion +from .text import LLMAugPE + +__all__ = ["API", "ImprovedDiffusion", "ImprovedDiffusion270M", "LLMAugPE", "StableDiffusion"] diff --git a/pe/api/text/__init__.py b/pe/api/text/__init__.py new file mode 100644 index 0000000..8aa4f6e --- /dev/null +++ b/pe/api/text/__init__.py @@ -0,0 +1 @@ +from .llm_augpe_api import LLMAugPE diff --git a/pe/api/text/llm_augpe_api.py b/pe/api/text/llm_augpe_api.py new file mode 100644 index 0000000..409c88d --- /dev/null +++ b/pe/api/text/llm_augpe_api.py @@ -0,0 +1,243 @@ +import json +import random +import copy +import pandas as pd +import tiktoken +import numpy as np + +from pe.api import API +from pe.api.util import ConstantList +from pe.logging import execution_logger +from pe.data import Data +from pe.llm import Request +from pe.constant.data import TEXT_DATA_COLUMN_NAME +from pe.constant.data import LLM_REQUEST_MESSAGES_COLUMN_NAME +from pe.constant.data import LLM_PARAMETERS_COLUMN_NAME +from pe.constant.data import LABEL_ID_COLUMN_NAME + + +class LLMAugPE(API): + """The text API that uses open-source or API-based LLMs. This algorithm is initially proposed in the ICML 2024 + Spotlight paper, "Differentially Private Synthetic Data via Foundation Model APIs 2: Text" + (https://arxiv.org/abs/2403.01749)""" + + def __init__( + self, + llm, + random_api_prompt_file, + variation_api_prompt_file, + min_word_count=0, + word_count_std=None, + token_to_word_ratio=None, + max_completion_tokens_limit=None, + blank_probabilities=None, + tokenizer_model="gpt-3.5-turbo", + ): + """Constructor. + + :param llm: The LLM utilized for the random and variation generation + :type llm: :py:class:`pe.llm.llm.LLM` + :param random_api_prompt_file: The prompt file for the random API. See the explanations to + ``variation_api_prompt_file`` for the format of the prompt file + :type random_api_prompt_file: str + :param variation_api_prompt_file: The prompt file for the variation API. The file is in JSON format and + contains the following fields: + + * ``message_template``: A list of messages that will be sent to the LLM. Each message contains the + following fields: + + * ``content``: The content of the message. The content can contain variable placeholders (e.g., + {variable_name}). The variable_name can be label name in the original data that will be replaced by + the actual label value; or "sample" that will be replaced by the input text to the variation API; + or "masked_sample" that will be replaced by the masked/blanked input text to the variation API + when the blanking feature is enabled; or "word_count" that will be replaced by the target word + count of the text when the word count variation feature is enabled; or other variables + specified in the replacement rules (see below). + * ``role``: The role of the message. The role can be "system", "user", or "assistant". + * ``replacement_rules``: A list of replacement rules that will be applied one by one to update the variable + list. 
Each replacement rule contains the following fields: + + * ``constraints``: A dictionary of constraints that must be satisfied for the replacement rule to be + applied. The key is the variable name and the value is the variable value. + * ``replacements``: A dictionary of replacements that will be used to update the variable list if the + constraints are satisfied. The key is the variable name and the value is the variable value or a + list of variable values to choose from in a uniform random manner. + :type variation_api_prompt_file: str + :param min_word_count: The minimum word count for the variation API, defaults to 0 + :type min_word_count: int, optional + :param word_count_std: The standard deviation for the word count for the variation API. If None, the word count + variation feature is disabled and "{word_count}" variable will not be provided to the prompt. Defaults to + None + :type word_count_std: float, optional + :param token_to_word_ratio: The token to word ratio for the variation API. If not None, the maximum completion + tokens will be set to ``token_to_word_ratio`` times the target word count when the word count variation + feature is enabled. Defaults to None + :type token_to_word_ratio: float, optional + :param max_completion_tokens_limit: The maximum completion tokens limit for the variation API, defaults to None + :type max_completion_tokens_limit: int, optional + :param blank_probabilities: The token blank probabilities for the variation API utilized at each PE iteration. + If a single float is provided, the same blank probability will be used for all iterations. If None, the + blanking feature is disabled and "{masked_sample}" variable will not be provided to the prompt. Defaults + to None + :type blank_probabilities: float or list[float], optional + :param tokenizer_model: The tokenizer model used for blanking, defaults to "gpt-3.5-turbo" + :type tokenizer_model: str, optional + """ + super().__init__() + self._llm = llm + + self._random_api_prompt_file = random_api_prompt_file + with open(random_api_prompt_file, "r") as f: + self._random_api_prompt_config = json.load(f) + + self._variation_api_prompt_file = variation_api_prompt_file + with open(variation_api_prompt_file, "r") as f: + self._variation_api_prompt_config = json.load(f) + + self._min_word_count = min_word_count + self._word_count_std = word_count_std + self._token_to_word_ratio = token_to_word_ratio + self._max_completion_tokens_limit = max_completion_tokens_limit + if isinstance(blank_probabilities, list): + self._blank_probabilities = blank_probabilities + else: + self._blank_probabilities = ConstantList(blank_probabilities) + + self._encoding = tiktoken.encoding_for_model(tokenizer_model) + self._mask_token = self._encoding.encode("_")[0] + + def _construct_prompt(self, prompt_config, variables): + """Applying the replacement rules to construct the final prompt messages. 
+ + :param prompt_config: The prompt configuration + :type prompt_config: dict + :param variables: The initial variables to be used in the prompt messages + :type variables: dict + :return: The constructed prompt messages + :rtype: list[dict] + """ + if "replacement_rules" in prompt_config: + for replacement_rule in prompt_config["replacement_rules"]: + constraints = replacement_rule["constraints"] + replacements = replacement_rule["replacements"] + satisfied = True + for key, value in constraints.items(): + if key not in variables or variables[key] != value: + satisfied = False + break + if satisfied: + for key, value in replacements.items(): + if isinstance(value, list): + value = random.choice(value) + variables[key] = value + messages = copy.deepcopy(prompt_config["message_template"]) + for message in messages: + message["content"] = message["content"].format(**variables) + return messages + + def random_api(self, label_info, num_samples): + """Generating random synthetic data. + + :param label_info: The info of the label + :type label_info: dict + :param num_samples: The number of random samples to generate + :type num_samples: int + :return: The data object of the generated synthetic data + :rtype: :py:class:`pe.data.data.Data` + """ + label_name = label_info.name + execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}") + + variables = label_info.column_values + execution_logger.info("RANDOM API: producing LLM requests") + messages_list = [ + self._construct_prompt(prompt_config=self._random_api_prompt_config, variables=copy.deepcopy(variables)) + for _ in range(num_samples) + ] + requests = [Request(messages=messages) for messages in messages_list] + execution_logger.info("RANDOM API: getting LLM responses") + responses = self._llm.get_responses(requests) + execution_logger.info("RANDOM API: constructing data") + data_frame = pd.DataFrame( + { + TEXT_DATA_COLUMN_NAME: responses, + LLM_REQUEST_MESSAGES_COLUMN_NAME: [json.dumps(messages) for messages in messages_list], + LABEL_ID_COLUMN_NAME: 0, + } + ) + metadata = {"label_info": [label_info]} + execution_logger.info(f"RANDOM API: finished creating {num_samples} samples for label {label_name}") + return Data(data_frame=data_frame, metadata=metadata) + + def _blank_sample(self, sample, blank_probability): + """Blanking the input text. + + :param sample: The input text + :type sample: str + :param blank_probability: The token blank probability + :type blank_probability: float + :return: The blanked input text + :rtype: str + """ + input_ids = np.asarray(self._encoding.encode(sample)) + masked_indices = np.random.uniform(size=len(input_ids)) < blank_probability + input_ids[masked_indices] = self._mask_token + return self._encoding.decode(input_ids) + + def variation_api(self, syn_data): + """Generating variations of the synthetic data.
+ + :param syn_data: The data object of the synthetic data + :type syn_data: :py:class:`pe.data.data.Data` + :return: The data object of the variation of the input synthetic data + :rtype: :py:class:`pe.data.data.Data` + """ + execution_logger.info(f"VARIATION API: creating variations for {len(syn_data.data_frame)} samples") + + samples = syn_data.data_frame[TEXT_DATA_COLUMN_NAME].tolist() + label_ids = syn_data.data_frame[LABEL_ID_COLUMN_NAME].tolist() + + iteration = getattr(syn_data.metadata, "iteration", -1) + blank_probability = self._blank_probabilities[iteration + 1] + + execution_logger.info("VARIATION API: producing LLM requests") + messages_list = [] + requests = [] + generation_args_list = [] + for sample, label_id in zip(samples, label_ids): + variables = {"sample": sample} + variables.update(syn_data.metadata.label_info[label_id].column_values) + generation_args = {} + + if blank_probability is not None: + variables["masked_sample"] = self._blank_sample(sample=sample, blank_probability=blank_probability) + + if self._word_count_std is not None: + word_count = len(sample.split()) + new_word_count = word_count + int(np.random.normal(loc=0, scale=self._word_count_std)) + new_word_count = max(self._min_word_count, new_word_count) + variables["word_count"] = new_word_count + + if self._token_to_word_ratio is not None: + max_completion_tokens = int(new_word_count * self._token_to_word_ratio) + if self._max_completion_tokens_limit is not None: + max_completion_tokens = min(max_completion_tokens, self._max_completion_tokens_limit) + generation_args["max_completion_tokens"] = max_completion_tokens + + messages = self._construct_prompt(prompt_config=self._variation_api_prompt_config, variables=variables) + messages_list.append(messages) + generation_args_list.append(generation_args) + requests.append(Request(messages=messages, generation_args=generation_args)) + execution_logger.info("VARIATION API: getting LLM responses") + responses = self._llm.get_responses(requests) + execution_logger.info("VARIATION API: constructing data") + data_frame = pd.DataFrame( + { + TEXT_DATA_COLUMN_NAME: responses, + LLM_REQUEST_MESSAGES_COLUMN_NAME: [json.dumps(messages) for messages in messages_list], + LLM_PARAMETERS_COLUMN_NAME: [json.dumps(generation_args) for generation_args in generation_args_list], + LABEL_ID_COLUMN_NAME: label_ids, + } + ) + execution_logger.info(f"VARIATION API: finished creating variations for {len(syn_data.data_frame)} samples") + return Data(data_frame=data_frame, metadata=syn_data.metadata) diff --git a/pe/callback/__init__.py b/pe/callback/__init__.py index acdc86e..25bb4ad 100644 --- a/pe/callback/__init__.py +++ b/pe/callback/__init__.py @@ -1,4 +1,8 @@ -from .common.save_checkpoints import SaveCheckpoints -from .common.compute_fid import ComputeFID -from .image.sample_images import SampleImages -from .image.save_all_images import SaveAllImages +from .callback import Callback +from .common import SaveCheckpoints +from .common import ComputeFID +from .image import SampleImages +from .image import SaveAllImages +from .text import SaveTextToCSV + +__all__ = ["Callback", "SaveCheckpoints", "ComputeFID", "SampleImages", "SaveAllImages", "SaveTextToCSV"] diff --git a/pe/callback/common/__init__.py b/pe/callback/common/__init__.py index e69de29..cc2fa7f 100644 --- a/pe/callback/common/__init__.py +++ b/pe/callback/common/__init__.py @@ -0,0 +1,2 @@ +from .compute_fid import ComputeFID +from .save_checkpoints import SaveCheckpoints diff --git a/pe/callback/image/__init__.py 
b/pe/callback/image/__init__.py index e69de29..59b0a16 100644 --- a/pe/callback/image/__init__.py +++ b/pe/callback/image/__init__.py @@ -0,0 +1,2 @@ +from .sample_images import SampleImages +from .save_all_images import SaveAllImages diff --git a/pe/callback/text/__init__.py b/pe/callback/text/__init__.py new file mode 100644 index 0000000..79fe8e7 --- /dev/null +++ b/pe/callback/text/__init__.py @@ -0,0 +1 @@ +from .save_text_to_csv import SaveTextToCSV diff --git a/pe/callback/text/save_text_to_csv.py b/pe/callback/text/save_text_to_csv.py new file mode 100644 index 0000000..7ece6c0 --- /dev/null +++ b/pe/callback/text/save_text_to_csv.py @@ -0,0 +1,64 @@ +import os +import pandas as pd + +from pe.callback.callback import Callback +from pe.constant.data import LABEL_ID_COLUMN_NAME +from pe.constant.data import TEXT_DATA_COLUMN_NAME +from pe.logging import execution_logger + + +class SaveTextToCSV(Callback): + """The callback that saves the synthetic text to a CSV file.""" + + def __init__( + self, + output_folder, + iteration_format="09d", + ): + """Constructor. + + :param output_folder: The output folder that will be used to save the CSV files + :type output_folder: str + :param iteration_format: The format of the iteration part of the CSV paths, defaults to "09d" + :type iteration_format: str, optional + """ + self._output_folder = output_folder + self._iteration_format = iteration_format + + def _get_csv_path(self, iteration): + """Get the CSV path. + + :param iteration: The PE iteration number + :type iteration: int + :return: The CSV path + :rtype: str + """ + os.makedirs(self._output_folder, exist_ok=True) + iteration_string = format(iteration, self._iteration_format) + csv_path = os.path.join( + self._output_folder, + f"{iteration_string}.csv", + ) + return csv_path + + def __call__(self, syn_data): + """This function is called after each PE iteration to save the synthetic text to a CSV file.
+ + :param syn_data: The :py:class:`pe.data.data.Data` object of the synthetic data + :type syn_data: :py:class:`pe.data.data.Data` + """ + execution_logger.info("Saving the synthetic text to a CSV file") + samples = syn_data.data_frame[TEXT_DATA_COLUMN_NAME].tolist() + label_ids = syn_data.data_frame[LABEL_ID_COLUMN_NAME].tolist() + columns = {syn_data.metadata.text_column: samples} + for i in range(len(syn_data.metadata.label_columns)): + column_name = syn_data.metadata.label_columns[i] + columns[column_name] = [ + syn_data.metadata.label_info[label_id].column_values[column_name] for label_id in label_ids + ] + + data_frame = pd.DataFrame(columns) + csv_path = self._get_csv_path(syn_data.metadata.iteration) + data_frame.to_csv(csv_path, index=False) + + execution_logger.info("Finished saving the synthetic text to a CSV file") diff --git a/pe/constant/data.py b/pe/constant/data.py index 0824790..a2ee2ad 100644 --- a/pe/constant/data.py +++ b/pe/constant/data.py @@ -25,5 +25,13 @@ #: The column name of the prompt for the image IMAGE_PROMPT_COLUMN_NAME = "PE.IMAGE_PROMPT" +#: The column name of the text data +TEXT_DATA_COLUMN_NAME = "PE.TEXT" + +#: The column name of the LLM request messages +LLM_REQUEST_MESSAGES_COLUMN_NAME = "PE.LLM.MESSAGES" +#: The column name of the LLM parameters +LLM_PARAMETERS_COLUMN_NAME = "PE.LLM.PARAMETERS" + #: The column name of the nearest neighbors voting IDs HISTOGRAM_NEAREST_NEIGHBORS_VOTING_IDS_COLUMN_NAME = "PE.HISTOGRAM.NEAREST_NEIGHBORS.VOTING_IDS" diff --git a/pe/data/__init__.py b/pe/data/__init__.py index 02a0feb..5e99230 100644 --- a/pe/data/__init__.py +++ b/pe/data/__init__.py @@ -1 +1,5 @@ from .data import Data +from .image import load_image_folder, Cifar10, Camelyon17, Cat +from .text import TextCSV, Yelp, PubMed, OpenReview + +__all__ = ["Data", "load_image_folder", "Cifar10", "Camelyon17", "Cat", "TextCSV", "Yelp", "PubMed", "OpenReview"] diff --git a/pe/data/text/__init__.py b/pe/data/text/__init__.py new file mode 100644 index 0000000..8d91864 --- /dev/null +++ b/pe/data/text/__init__.py @@ -0,0 +1,4 @@ +from .text_csv import TextCSV +from .yelp import Yelp +from .pubmed import PubMed +from .openreview import OpenReview diff --git a/pe/data/text/openreview.py b/pe/data/text/openreview.py new file mode 100644 index 0000000..f459a5f --- /dev/null +++ b/pe/data/text/openreview.py @@ -0,0 +1,87 @@ +import os +import pandas as pd +from collections import namedtuple + +from .text_csv import TextCSV +from pe.util import download +import gdown +import csv + +DownloadInfo = namedtuple("DownloadInfo", ["url", "type"]) + + +class OpenReview(TextCSV): + """The OpenReview dataset in the ICML 2024 Spotlight paper, "Differentially Private Synthetic Data via Foundation + Model APIs 2: Text" (https://arxiv.org/abs/2403.01749).""" + + #: The download information for the OpenReview dataset. 
+ DOWNLOAD_INFO_DICT = { + "train": DownloadInfo( + url=( + "https://raw.githubusercontent.com/AI-secure/aug-pe/bca21c90921bd1151aa7627e676c906165e205a0/data/" + "openreview/iclr23_reviews_train.csv" + ), + type="direct", + ), + "val": DownloadInfo( + url=( + "https://raw.githubusercontent.com/AI-secure/aug-pe/bca21c90921bd1151aa7627e676c906165e205a0/data/" + "openreview/iclr23_reviews_val.csv" + ), + type="direct", + ), + "test": DownloadInfo( + url=( + "https://raw.githubusercontent.com/AI-secure/aug-pe/bca21c90921bd1151aa7627e676c906165e205a0/data/" + "openreview/iclr23_reviews_test.csv" + ), + type="direct", + ), + } + + def __init__(self, root_dir="data", split="train", **kwargs): + """Constructor. + + :param root_dir: The root directory of the dataset. If the dataset is not there, it will be downloaded + automatically. Defaults to "data" + :type root_dir: str, optional + :param split: The split of the dataset. It should be either "train", "val", or "test", defaults to "train" + :type split: str, optional + """ + self._processed_data_path = os.path.join(root_dir, f"{split}_processed.csv") + self._data_path = os.path.join(root_dir, f"{split}.csv") + self._download( + download_info=self.DOWNLOAD_INFO_DICT[split], + data_path=self._data_path, + processed_data_path=self._processed_data_path, + ) + super().__init__( + csv_path=self._processed_data_path, label_columns=["area", "recommendation"], text_column="text", **kwargs + ) + + def _download(self, download_info, data_path, processed_data_path): + """Download the dataset. + + :param download_info: The download information + :type download_info: pe.data.text.openreview.DownloadInfo + :param data_path: The path to the raw data + :type data_path: str + :param processed_data_path: The path to the processed data + :type processed_data_path: str + :raises ValueError: If the download type is unknown + """ + os.makedirs(os.path.dirname(processed_data_path), exist_ok=True) + os.makedirs(os.path.dirname(data_path), exist_ok=True) + if not os.path.exists(processed_data_path): + if not os.path.exists(data_path): + if download_info.type == "gdown": + gdown.download(url=download_info.url, output=data_path) + elif download_info.type == "direct": + download(url=download_info.url, fname=data_path) + else: + raise ValueError(f"Unknown download type: {download_info.type}") + data_frame = pd.read_csv(data_path, dtype=str) + data_frame["label1"] = data_frame["label1"].str.replace("Area: ", "") + data_frame["label2"] = data_frame["label2"].str.replace("Recommendation: ", "") + data_frame = data_frame.rename(columns={"label1": "area", "label2": "recommendation"}) + data_frame.to_csv(processed_data_path, index=False, quoting=csv.QUOTE_ALL) diff --git a/pe/data/text/pubmed.py b/pe/data/text/pubmed.py new file mode 100644 index 0000000..cf81eae --- /dev/null +++ b/pe/data/text/pubmed.py @@ -0,0 +1,66 @@ +import os +from collections import namedtuple + +from .text_csv import TextCSV +from pe.util import download +import gdown + +DownloadInfo = namedtuple("DownloadInfo", ["url", "type"]) + + +class PubMed(TextCSV): + """The PubMed dataset in the ICML 2024 Spotlight paper, "Differentially Private Synthetic Data via Foundation + Model APIs 2: Text" (https://arxiv.org/abs/2403.01749).""" + + #: The download information for the PubMed dataset. 
+ DOWNLOAD_INFO_DICT = { + "train": DownloadInfo(url="https://drive.google.com/uc?id=12-zV93MQNPvM_ORUoahZ2n4odkkOXD-r", type="gdown"), + "val": DownloadInfo( + url=( + "https://raw.githubusercontent.com/AI-secure/aug-pe/bca21c90921bd1151aa7627e676c906165e205a0/" + "data/pubmed/dev.csv" + ), + type="direct", + ), + "test": DownloadInfo( + url=( + "https://raw.githubusercontent.com/AI-secure/aug-pe/bca21c90921bd1151aa7627e676c906165e205a0/" + "data/pubmed/test.csv" + ), + type="direct", + ), + } + + def __init__(self, root_dir="data", split="train", **kwargs): + """Constructor. + + :param root_dir: The root directory of the dataset. If the dataset is not there, it will be downloaded + automatically. Defaults to "data" + :type root_dir: str, optional + :param split: The split of the dataset. It should be either "train", "val", or "test", defaults to "train" + :type split: str, optional + """ + self._data_path = os.path.join(root_dir, f"{split}.csv") + self._download( + download_info=self.DOWNLOAD_INFO_DICT[split], + data_path=self._data_path, + ) + super().__init__(csv_path=self._data_path, label_columns=[], text_column="text", **kwargs) + + def _download(self, download_info, data_path): + """Download the dataset. + + :param download_info: The download information + :type download_info: pe.data.text.pubmed.DownloadInfo + :param data_path: The path to the raw data + :type data_path: str + :raises ValueError: If the download type is unknown + """ + os.makedirs(os.path.dirname(data_path), exist_ok=True) + if not os.path.exists(data_path): + if download_info.type == "gdown": + gdown.download(url=download_info.url, output=data_path) + elif download_info.type == "direct": + download(url=download_info.url, fname=data_path) + else: + raise ValueError(f"Unknown download type: {download_info.type}") diff --git a/pe/data/text/text_csv.py b/pe/data/text/text_csv.py new file mode 100644 index 0000000..a9bf14e --- /dev/null +++ b/pe/data/text/text_csv.py @@ -0,0 +1,44 @@ +from pe.data import Data +import pandas as pd +from pe.constant.data import LABEL_ID_COLUMN_NAME +from pe.constant.data import TEXT_DATA_COLUMN_NAME + + +class TextCSV(Data): + """The text dataset in CSV format.""" + + def __init__(self, csv_path, label_columns=[], text_column="text", num_samples=None): + """Constructor. + + :param csv_path: The path to the CSV file + :type csv_path: str + :param label_columns: The names of the columns that contain the labels, defaults to [] + :type label_columns: list, optional + :param text_column: The name of the column that contains the text data, defaults to "text" + :type text_column: str, optional + :param num_samples: The number of samples to load from the CSV file. If None, load all samples. 
Defaults to + None + :type num_samples: int, optional + :raises ValueError: If the label columns or text column does not exist in the CSV file + """ + data_frame = pd.read_csv(csv_path, dtype=str) + if num_samples is not None: + data_frame = data_frame[:num_samples] + for column in label_columns + [text_column]: + if column not in data_frame.columns: + raise ValueError(f"Column {column} does not exist in the CSV file") + labels = data_frame.apply(lambda row: tuple([row[col] for col in label_columns]), axis=1).tolist() + label_set = list(sorted(set(labels))) + label_id_map = {label: i for i, label in enumerate(label_set)} + label_ids = [label_id_map[label] for label in labels] + data_frame[LABEL_ID_COLUMN_NAME] = label_ids + label_info = [ + { + "name": " | ".join(f"{label_columns[i]}: {label[i]}" for i in range(len(label_columns))), + "column_values": {label_columns[i]: label[i] for i in range(len(label_columns))}, + } + for label in label_set + ] + metadata = {"label_columns": label_columns, "text_column": text_column, "label_info": label_info} + data_frame = data_frame.rename(columns={text_column: TEXT_DATA_COLUMN_NAME}) + super().__init__(data_frame=data_frame, metadata=metadata) diff --git a/pe/data/text/yelp.py b/pe/data/text/yelp.py new file mode 100644 index 0000000..e1e3540 --- /dev/null +++ b/pe/data/text/yelp.py @@ -0,0 +1,84 @@ +import os +import pandas as pd +from collections import namedtuple + +from .text_csv import TextCSV +from pe.util import download +import gdown +import csv + +DownloadInfo = namedtuple("DownloadInfo", ["url", "type"]) + + +class Yelp(TextCSV): + """The Yelp dataset in the ICML 2024 Spotlight paper, "Differentially Private Synthetic Data via Foundation + Model APIs 2: Text" (https://arxiv.org/abs/2403.01749).""" + + #: The download information for the Yelp dataset. + DOWNLOAD_INFO_DICT = { + "train": DownloadInfo(url="https://drive.google.com/uc?id=1epLuBxCk5MGnm1GiIfLcTcr-tKgjCrc2", type="gdown"), + "val": DownloadInfo( + url=( + "https://raw.githubusercontent.com/AI-secure/aug-pe/bca21c90921bd1151aa7627e676c906165e205a0/data/" + "yelp/dev.csv" + ), + type="direct", + ), + "test": DownloadInfo( + url=( + "https://raw.githubusercontent.com/AI-secure/aug-pe/bca21c90921bd1151aa7627e676c906165e205a0/data/" + "yelp/test.csv" + ), + type="direct", + ), + } + + def __init__(self, root_dir="data", split="train", **kwargs): + """Constructor. + + :param root_dir: The root directory of the dataset. If the dataset is not there, it will be downloaded + automatically. Defaults to "data" + :type root_dir: str, optional + :param split: The split of the dataset. It should be either "train", "val", or "test", defaults to "train" + :type split: str, optional + """ + self._processed_data_path = os.path.join(root_dir, f"{split}_processed.csv") + self._data_path = os.path.join(root_dir, f"{split}.csv") + self._download( + download_info=self.DOWNLOAD_INFO_DICT[split], + data_path=self._data_path, + processed_data_path=self._processed_data_path, + ) + super().__init__( + csv_path=self._processed_data_path, + label_columns=["business_category", "review_stars"], + text_column="text", + **kwargs, + ) + + def _download(self, download_info, data_path, processed_data_path): + """Download the dataset. 
+ + :param download_info: The download information + :type download_info: pe.data.text.yelp.DownloadInfo + :param data_path: The path to the raw data + :type data_path: str + :param processed_data_path: The path to the processed data + :type processed_data_path: str + :raises ValueError: If the download type is unknown + """ + os.makedirs(os.path.dirname(processed_data_path), exist_ok=True) + os.makedirs(os.path.dirname(data_path), exist_ok=True) + if not os.path.exists(processed_data_path): + if not os.path.exists(data_path): + if download_info.type == "gdown": + gdown.download(url=download_info.url, output=data_path) + elif download_info.type == "direct": + download(url=download_info.url, fname=data_path) + else: + raise ValueError(f"Unknown download type: {download_info.type}") + data_frame = pd.read_csv(data_path, dtype=str) + data_frame["label1"] = data_frame["label1"].str.replace("Business Category: ", "") + data_frame["label2"] = data_frame["label2"].str.replace("Review Stars: ", "") + data_frame = data_frame.rename(columns={"label1": "business_category", "label2": "review_stars"}) + data_frame.to_csv(processed_data_path, index=False, quoting=csv.QUOTE_ALL) diff --git a/pe/dp/__init__.py b/pe/dp/__init__.py index 5cbddc4..ea2fccc 100644 --- a/pe/dp/__init__.py +++ b/pe/dp/__init__.py @@ -1,2 +1,4 @@ from .dp import DP from .gaussian import Gaussian + +__all__ = ["DP", "Gaussian"] diff --git a/pe/embedding/__init__.py b/pe/embedding/__init__.py index b42f9d1..81d0627 100644 --- a/pe/embedding/__init__.py +++ b/pe/embedding/__init__.py @@ -1 +1,5 @@ from .embedding import Embedding +from .image import Inception +from .text import SentenceTransformer + +__all__ = ["Embedding", "Inception", "SentenceTransformer"] diff --git a/pe/embedding/text/__init__.py b/pe/embedding/text/__init__.py new file mode 100644 index 0000000..9da894b --- /dev/null +++ b/pe/embedding/text/__init__.py @@ -0,0 +1 @@ +from .sentence_transformer import SentenceTransformer diff --git a/pe/embedding/text/sentence_transformer.py b/pe/embedding/text/sentence_transformer.py new file mode 100644 index 0000000..0202487 --- /dev/null +++ b/pe/embedding/text/sentence_transformer.py @@ -0,0 +1,56 @@ +import pandas as pd +from sentence_transformers import SentenceTransformer as ST + +from pe.embedding import Embedding +from pe.logging import execution_logger +from pe.constant.data import TEXT_DATA_COLUMN_NAME +from pe.constant.data import EMBEDDING_COLUMN_NAME + + +class SentenceTransformer(Embedding): + """Compute the Sentence Transformers embedding of text.""" + + def __init__(self, model, batch_size=2000): + """Constructor. + + :param model: The Sentence Transformers model to use + :type model: str + :param batch_size: The batch size to use for computing the embedding, defaults to 2000 + :type batch_size: int, optional + """ + super().__init__() + self._model_name = model + self._model = ST(model) + self._batch_size = batch_size + + @property + def column_name(self): + """The column name to be used in the data frame.""" + return f"{EMBEDDING_COLUMN_NAME}.{type(self).__name__}.{self._model_name}" + + def compute_embedding(self, data): + """Compute the Sentence Transformers embedding of text. 
+ + :param data: The data object containing the text + :type data: :py:class:`pe.data.data.Data` + :return: The data object with the computed embedding + :rtype: :py:class:`pe.data.data.Data` + """ + uncomputed_data = self.filter_uncomputed_rows(data) + if len(uncomputed_data.data_frame) == 0: + execution_logger.info(f"Embedding: {self.column_name} already computed") + return data + execution_logger.info( + f"Embedding: computing {self.column_name} for {len(uncomputed_data.data_frame)}/{len(data.data_frame)}" + " samples" + ) + samples = uncomputed_data.data_frame[TEXT_DATA_COLUMN_NAME].tolist() + embeddings = self._model.encode(samples, batch_size=self._batch_size) + uncomputed_data.data_frame[self.column_name] = pd.Series( + list(embeddings), index=uncomputed_data.data_frame.index + ) + execution_logger.info( + f"Embedding: finished computing {self.column_name} for " + f"{len(uncomputed_data.data_frame)}/{len(data.data_frame)} samples" + ) + return self.merge_computed_rows(data, uncomputed_data) diff --git a/pe/histogram/__init__.py b/pe/histogram/__init__.py index 7db90c1..d25317a 100644 --- a/pe/histogram/__init__.py +++ b/pe/histogram/__init__.py @@ -1,2 +1,4 @@ from .histogram import Histogram from .nearest_neighbors import NearestNeighbors + +__all__ = ["Histogram", "NearestNeighbors"] diff --git a/pe/llm/__init__.py b/pe/llm/__init__.py new file mode 100644 index 0000000..f148076 --- /dev/null +++ b/pe/llm/__init__.py @@ -0,0 +1,7 @@ +from .llm import LLM +from .request import Request +from .openai import OpenAILLM +from .azure_openai import AzureOpenAILLM +from .huggingface.huggingface import HuggingfaceLLM + +__all__ = ["LLM", "Request", "OpenAILLM", "AzureOpenAILLM", "HuggingfaceLLM"] diff --git a/pe/llm/azure_openai.py b/pe/llm/azure_openai.py new file mode 100644 index 0000000..2959390 --- /dev/null +++ b/pe/llm/azure_openai.py @@ -0,0 +1,145 @@ +from openai import AzureOpenAI +from openai import BadRequestError +from openai import AuthenticationError +from openai import NotFoundError +from openai import PermissionDeniedError +from azure.identity import AzureCliCredential, get_bearer_token_provider +import os +from tenacity import retry +from tenacity import retry_if_not_exception_type +from tenacity import stop_after_attempt +from tenacity import wait_random_exponential +from tenacity import before_sleep_log +import json +import logging +from concurrent.futures import ThreadPoolExecutor +import random + +from pe.logging import execution_logger +from .llm import LLM + + +class AzureOpenAILLM(LLM): + """A wrapper for Azure OpenAI LLM APIs. The following environment variables are required: + + * ``AZURE_OPENAI_API_KEY``: Azure OpenAI API key. You can get it from https://portal.azure.com/. Multiple keys can + be separated by commas, and a key will be selected randomly for each request. The key can also be "AZ_CLI", in + which case the Azure CLI will be used to authenticate the requests, and the environment variable + ``AZURE_OPENAI_API_SCOPE`` needs to be set. See Azure OpenAI authentication documentation for more information: + https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints#microsoft-entra-id-authentication + * ``AZURE_OPENAI_ENDPOINT``: Azure OpenAI endpoint. You can get it from https://portal.azure.com/. + * ``AZURE_OPENAI_API_VERSION``: Azure OpenAI API version. You can get it from https://portal.azure.com/.""" + + def __init__(self, dry_run=False, num_threads=1, **generation_args): + """Constructor. 
+
+        :param dry_run: Whether to enable dry run. When dry run is enabled, the responses are fake and the APIs are
+            not called. Defaults to False
+        :type dry_run: bool, optional
+        :param num_threads: The number of threads to use for making concurrent API calls, defaults to 1
+        :type num_threads: int, optional
+        :param \\*\\*generation_args: The generation arguments that will be passed to the Azure OpenAI API
+        :type \\*\\*generation_args: str
+        """
+        self._dry_run = dry_run
+        self._num_threads = num_threads
+        self._generation_args = generation_args
+
+        self._api_keys = self._get_environment_variable("AZURE_OPENAI_API_KEY").split(",")
+        self._clients = []
+        for api_key in self._api_keys:
+            if api_key == "AZ_CLI":
+                credential = get_bearer_token_provider(
+                    AzureCliCredential(), self._get_environment_variable("AZURE_OPENAI_API_SCOPE")
+                )
+                client = AzureOpenAI(
+                    azure_ad_token_provider=credential,
+                    api_version=self._get_environment_variable("AZURE_OPENAI_API_VERSION"),
+                    azure_endpoint=self._get_environment_variable("AZURE_OPENAI_API_ENDPOINT"),
+                )
+            else:
+                client = AzureOpenAI(
+                    api_key=api_key,
+                    api_version=self._get_environment_variable("AZURE_OPENAI_API_VERSION"),
+                    azure_endpoint=self._get_environment_variable("AZURE_OPENAI_API_ENDPOINT"),
+                )
+            self._clients.append(client)
+        execution_logger.info(f"Using {len(self._api_keys)} AzureOpenAI API keys")
+
+    @property
+    def generation_arg_map(self):
+        """Get the mapping from the generation arguments to arguments for this specific LLM.
+
+        :return: The mapping that maps ``max_completion_tokens`` to ``max_tokens``
+        :rtype: dict
+        """
+        return {"max_completion_tokens": "max_tokens"}
+
+    def _get_environment_variable(self, name):
+        """Get the environment variable.
+
+        :param name: The name of the environment variable
+        :type name: str
+        :raises ValueError: If the environment variable is not set
+        :return: The value of the environment variable
+        :rtype: str
+        """
+        if name not in os.environ or os.environ[name] == "":
+            raise ValueError(f"{name} environment variable is not set.")
+        return os.environ[name]
+
+    def get_responses(self, requests, **generation_args):
+        """Get the responses from the LLM.
+
+        :param requests: The requests
+        :type requests: list[:py:class:`pe.llm.request.Request`]
+        :param \\*\\*generation_args: The generation arguments. The priority of the generation arguments from the
+            highest to the lowest is in the order of: the arguments set in the requests > the arguments passed to
+            this function > the arguments passed to the constructor
+        :type \\*\\*generation_args: str
+        :return: The responses
+        :rtype: list[str]
+        """
+        messages_list = [request.messages for request in requests]
+        generation_args_list = [
+            self.get_generation_args(self._generation_args, generation_args, request.generation_args)
+            for request in requests
+        ]
+        with ThreadPoolExecutor(max_workers=self._num_threads) as executor:
+            responses = list(executor.map(self._get_response_for_one_request, messages_list, generation_args_list))
+        return responses
+
+    @retry(
+        retry=retry_if_not_exception_type(
+            (
+                BadRequestError,
+                AuthenticationError,
+                NotFoundError,
+                PermissionDeniedError,
+            )
+        ),
+        wait=wait_random_exponential(min=8, max=500),
+        stop=stop_after_attempt(30),
+        before_sleep=before_sleep_log(execution_logger, logging.DEBUG),
+    )
+    def _get_response_for_one_request(self, messages, generation_args):
+        """Get the response for one request.
+
+        :param messages: The messages
+        :type messages: list[str]
+        :param generation_args: The generation arguments
+        :type generation_args: dict
+        :return: The response
+        :rtype: str
+        """
+        if self._dry_run:
+            response = f"Dry run enabled. The request is {json.dumps(messages)}"
+        else:
+            client = random.choice(self._clients)
+            full_response = client.chat.completions.create(
+                messages=messages,
+                **generation_args,
+            )
+            response = full_response.choices[0].message.content
+        return response
diff --git a/pe/llm/huggingface/__init__.py b/pe/llm/huggingface/__init__.py
new file mode 100644
index 0000000..6e2f6f1
--- /dev/null
+++ b/pe/llm/huggingface/__init__.py
@@ -0,0 +1,3 @@
+from .register_fastchat.gpt2 import register as register_gpt2
+
+register_gpt2()
diff --git a/pe/llm/huggingface/huggingface.py b/pe/llm/huggingface/huggingface.py
new file mode 100644
index 0000000..e8420c6
--- /dev/null
+++ b/pe/llm/huggingface/huggingface.py
@@ -0,0 +1,162 @@
+import torch
+import transformers
+from pe.logging import execution_logger
+from fastchat.model.model_adapter import get_conversation_template
+
+from ..llm import LLM
+
+
+class HuggingfaceLLM(LLM):
+    """A wrapper for Huggingface LLMs."""
+
+    def __init__(self, model_name_or_path, batch_size=128, dry_run=False, **generation_args):
+        """Constructor.
+
+        :param model_name_or_path: The model name or path of the Huggingface model
+        :type model_name_or_path: str
+        :param batch_size: The batch size to use for generating the responses, defaults to 128
+        :type batch_size: int, optional
+        :param dry_run: Whether to enable dry run. When dry run is enabled, the responses are fake and the LLMs are
+            not called. Defaults to False
+        :type dry_run: bool, optional
+        :param \\*\\*generation_args: The generation arguments that will be passed to the model's ``generate`` method
+        :type \\*\\*generation_args: str
+        """
+        self._dry_run = dry_run
+        self._generation_args = generation_args
+
+        self._model_name_or_path = model_name_or_path
+        self._batch_size = batch_size
+
+        self._tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path, device_map="auto")
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
+        self._tokenizer.padding_side = "left"
+
+        self._model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_name_or_path, device_map="auto", torch_dtype=torch.float16
+        )
+        if self._model.config.pad_token_id is None:
+            self._model.config.pad_token_id = self._model.config.eos_token_id
+        self._model.eval()
+
+        self._conv_template = self._get_conv_template()
+        self._stop_str = self._conv_template.stop_str
+        self._stop_token_ids = self._conv_template.stop_token_ids or []
+        self._stop_token_ids.append(self._tokenizer.eos_token_id)
+
+    @property
+    def generation_arg_map(self):
+        """Get the mapping from the generation arguments to arguments for this specific LLM.
+
+        :return: The mapping that maps ``max_completion_tokens`` to ``max_new_tokens``
+        :rtype: dict
+        """
+        return {"max_completion_tokens": "max_new_tokens"}
+
+    def _get_conv_template(self):
+        """Get the conversation template.
+
+        :return: The empty conversation template for this model from FastChat
+        :rtype: :py:class:`fastchat.conversation.Conversation`
+        """
+        template = get_conversation_template(self._model_name_or_path)
+        template.messages = []
+        template.system_message = ""
+        return template
+
+    def _get_prompt(self, messages):
+        """Get the prompt from the messages.
+ + :param messages: The messages + :type messages: list[dict] + :raises ValueError: If the role is invalid + :return: The prompt + :rtype: str + """ + template = self._conv_template.copy() + for message in messages: + if message["role"] == "system": + template.set_system_message(message["content"]) + elif message["role"] == "user": + template.append_message(role=template.roles[0], message=message["content"]) + elif message["role"] == "assistant": + template.append_message(role=template.roles[1], message=message["content"]) + else: + raise ValueError(f"Invalid role: {message['role']}") + template.append_message(role=template.roles[1], message=None) + return template.get_prompt() + + def get_responses(self, requests, **generation_args): + """Get the responses from the LLM. + + :param requests: The requests + :type requests: list[:py:class:`pe.llm.request.Request`] + :param \\*\\*generation_args: The generation arguments. The priority of the generation arguments from the + highest to the lowerest is in the order of: the arguments set in the requests > the arguments passed to + this function > and the arguments passed to the constructor + :type \\*\\*generation_args: str + :return: The responses + :rtype: list[str] + """ + execution_logger.info("HuggingfaceLLM: producing prompts") + prompt_list = [] + generation_args_list = [] + for request in requests: + prompt_list.append(self._get_prompt(request.messages)) + generation_args_list.append( + self.get_generation_args(self._generation_args, generation_args, request.generation_args) + ) + execution_logger.info("HuggingfaceLLM: getting responses") + responses = [None] * len(requests) + # Group requests according to generation_args + generation_args_fronzen_set_list = [ + frozenset(generation_args.items()) for generation_args in generation_args_list + ] + generation_args_set = list(set(generation_args_fronzen_set_list)) + generation_args_to_set_index = {g: i for i, g in enumerate(generation_args_set)} + grouped_request_indices = [[] for i in range(len(generation_args_set))] + for i, generation_args in enumerate(generation_args_fronzen_set_list): + grouped_request_indices[generation_args_to_set_index[generation_args]].append(i) + for group in grouped_request_indices: + sub_prompt_list = [prompt_list[j] for j in group] + sub_response_list = self._get_responses(sub_prompt_list, generation_args_list[group[0]]) + for i, j in enumerate(group): + responses[j] = sub_response_list[i] + assert None not in responses + return responses + + @torch.no_grad + def _get_responses(self, prompt_list, generation_args): + """Get the responses from the LLM. + + :param prompt_list: The prompts + :type prompt_list: list[str] + :param generation_args: The generation arguments + :type generation_args: dict + :return: The responses + :rtype: list[str] + """ + if self._dry_run: + responses = [f"Dry run enabled. 
The request is {prompt}" for prompt in prompt_list] + else: + input_ids = self._tokenizer( + prompt_list, return_tensors="pt", padding=True, padding_side="left" + ).input_ids.to(self._model.device) + responses = [] + for i in range(0, len(input_ids), self._batch_size): + batch_input_ids = input_ids[i : i + self._batch_size] + batch_responses = self._model.generate( + batch_input_ids, + stop_strings=self._stop_str, + eos_token_id=self._stop_token_ids, + do_sample=True, + **generation_args, + ) + batch_responses = self._tokenizer.batch_decode( + batch_responses[:, input_ids.shape[1] :], + clean_up_tokenization_spaces=True, + skip_special_tokens=True, + ) + responses.extend(batch_responses) + return responses diff --git a/pe/llm/huggingface/register_fastchat/__init__.py b/pe/llm/huggingface/register_fastchat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pe/llm/huggingface/register_fastchat/gpt2.py b/pe/llm/huggingface/register_fastchat/gpt2.py new file mode 100644 index 0000000..36bafaa --- /dev/null +++ b/pe/llm/huggingface/register_fastchat/gpt2.py @@ -0,0 +1,33 @@ +from fastchat.model.model_adapter import BaseModelAdapter +from fastchat.model.model_adapter import register_model_adapter +from fastchat.conversation import get_conv_template +from fastchat.conversation import register_conv_template +from fastchat.conversation import Conversation +from fastchat.conversation import SeparatorStyle + + +class GPT2Adapter(BaseModelAdapter): + """The GPT-2 model adapter for fastchat.""" + + def match(self, model_path): + return "gpt2" in model_path.lower() + + def load_model(self, model_path, from_pretrained_kwargs): + raise NotImplementedError + + def get_default_conv_template(self, model_path): + return get_conv_template("gpt2") + + +def register(): + """Register the GPT-2 model adapter for fastchat.""" + register_conv_template( + Conversation( + name="gpt2", + system_message="", + roles=("", ""), + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + ) + ) + register_model_adapter(GPT2Adapter) diff --git a/pe/llm/llm.py b/pe/llm/llm.py new file mode 100644 index 0000000..6cf4d5e --- /dev/null +++ b/pe/llm/llm.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from functools import reduce + + +class LLM(ABC): + """The abstract class for large language models (LLMs).""" + + @abstractmethod + def get_responses(self, requests, **generation_args): + """Get the responses from the LLM. + + :param requests: The requests + :type requests: list[:py:class:`pe.llm.request.Request`] + :param \\*\\*generation_args: The generation arguments + :type \\*\\*generation_args: str + :return: The responses + :rtype: list[str] + """ + ... + + @property + def generation_arg_map(self): + """Get the mapping from the generation arguments to arguments for this specific LLM. + + :return: The mapping from the generation arguments to the large language model arguments + :rtype: dict + """ + return {} + + def get_generation_args(self, *args): + """Get the generation arguments from a list of dictionaries. + + :param \\*args: A list of generation arguments. The later ones will overwrite the earlier ones. 
+ :type \\*args: dict + :return: The generation arguments + :rtype: dict + """ + generation_args = reduce(lambda x, y: {**x, **y}, args) + generation_args = { + k if k not in self.generation_arg_map else self.generation_arg_map[k]: v + for k, v in generation_args.items() + } + return generation_args diff --git a/pe/llm/openai.py b/pe/llm/openai.py new file mode 100644 index 0000000..49900fc --- /dev/null +++ b/pe/llm/openai.py @@ -0,0 +1,112 @@ +from openai import OpenAI +from openai import BadRequestError +from openai import AuthenticationError +from openai import NotFoundError +from openai import PermissionDeniedError +import os +from tenacity import retry +from tenacity import retry_if_not_exception_type +from tenacity import stop_after_attempt +from tenacity import wait_random_exponential +from tenacity import before_sleep_log +import json +import logging +from concurrent.futures import ThreadPoolExecutor +import random + +from pe.logging import execution_logger +from .llm import LLM + + +class OpenAILLM(LLM): + """A wrapper for OpenAI LLM APIs. The following environment variables are required: + + * ``OPENAI_API_KEY``: OpenAI API key. You can get it from https://platform.openai.com/account/api-keys. Multiple + keys can be separated by commas, and a key will be selected randomly for each request.""" + + def __init__(self, dry_run=False, num_threads=1, **generation_args): + """Constructor. + + :param dry_run: Whether to enable dry run. When dry run is enabled, the responses are fake and the APIs are + not called. Defaults to False + :type dry_run: bool, optional + :param num_threads: The number of threads to use for making concurrent API calls, defaults to 1 + :type num_threads: int, optional + :param \\*\\*generation_args: The generation arguments that will be passed to the OpenAI API + :type \\*\\*generation_args: str + """ + self._dry_run = dry_run + self._num_threads = num_threads + self._generation_args = generation_args + + self._api_keys = self._get_environment_variable("OPENAI_API_KEY").split(",") + self._clients = [OpenAI(api_key=api_key) for api_key in self._api_keys] + execution_logger.info(f"Using {len(self._api_keys)} OpenAI API keys") + + def _get_environment_variable(self, name): + """Get the environment variable. + + :param name: The name of the environment variable + :type name: str + :raises ValueError: If the environment variable is not set + :return: The value of the environment variable + :rtype: str + """ + if name not in os.environ or os.environ[name] == "": + raise ValueError(f"{name} environment variable is not set.") + return os.environ[name] + + def get_responses(self, requests, **generation_args): + """Get the responses from the LLM. + + :param requests: The requests + :type requests: list[:py:class:`pe.llm.request.Request`] + :param \\*\\*generation_args: The generation arguments. 
The priority of the generation arguments from the + highest to the lowerest is in the order of: the arguments set in the requests > the arguments passed to + this function > and the arguments passed to the constructor + :type \\*\\*generation_args: str + :return: The responses + :rtype: list[str] + """ + messages_list = [request.messages for request in requests] + generation_args_list = [ + self.get_generation_args(self._generation_args, generation_args, request.generation_args) + for request in requests + ] + with ThreadPoolExecutor(max_workers=self._num_threads) as executor: + responses = list(executor.map(self._get_response_for_one_request, messages_list, generation_args_list)) + return responses + + @retry( + retry=retry_if_not_exception_type( + ( + BadRequestError, + AuthenticationError, + NotFoundError, + PermissionDeniedError, + ) + ), + wait=wait_random_exponential(min=8, max=500), + stop=stop_after_attempt(30), + before_sleep=before_sleep_log(execution_logger, logging.DEBUG), + ) + def _get_response_for_one_request(self, messages, generation_args): + """Get the response for one request. + + :param messages: The messages + :type messages: list[str] + :param generation_args: The generation arguments + :type generation_args: dict + :return: The response + :rtype: str + """ + if self._dry_run: + response = f"Dry run enabled. The request is {json.dumps(messages)}" + else: + client = random.choice(self._clients) + full_response = client.chat.completions.create( + messages=messages, + **generation_args, + ) + response = full_response.choices[0].message.content + return response diff --git a/pe/llm/request.py b/pe/llm/request.py new file mode 100644 index 0000000..72ce151 --- /dev/null +++ b/pe/llm/request.py @@ -0,0 +1,11 @@ +from collections import namedtuple + + +Request = namedtuple("Request", ["messages", "generation_args"], defaults=[[], {}]) +""" The request to the LLM. + +:param messages: The messages to the LLM +:type messages: list[dict] +:param generation_args: The generation arguments to the LLM +:type generation_args: dict +""" diff --git a/pe/logger/__init__.py b/pe/logger/__init__.py index 0447e54..4e92098 100644 --- a/pe/logger/__init__.py +++ b/pe/logger/__init__.py @@ -1,4 +1,7 @@ +from .logger import Logger from .csv_print import CSVPrint from .image_file import ImageFile from .log_print import LogPrint from .matplotlib_pdf import MatplotlibPDF + +__all__ = ["Logger", "CSVPrint", "ImageFile", "LogPrint", "MatplotlibPDF"] diff --git a/pe/population/__init__.py b/pe/population/__init__.py index 1a5cb0c..05dacff 100644 --- a/pe/population/__init__.py +++ b/pe/population/__init__.py @@ -1,2 +1,4 @@ from .population import Population from .pe_population import PEPopulation + +__all__ = ["Population", "PEPopulation"] diff --git a/pe/population/pe_population.py b/pe/population/pe_population.py index ef4a174..686b664 100644 --- a/pe/population/pe_population.py +++ b/pe/population/pe_population.py @@ -15,7 +15,7 @@ class PEPopulation(Population): def __init__( self, api, - histogram_threshold, + histogram_threshold=None, initial_variation_api_fold=0, next_variation_api_fold=1, keep_selected=False, @@ -25,8 +25,8 @@ def __init__( :param api: The API object that contains the random and variation APIs :type api: :py:class:`pe.api.api.API` - :param histogram_threshold: The threshold for clipping the histogram - :type histogram_threshold: float + :param histogram_threshold: The threshold for clipping the histogram. None means no clipping. 
Defaults to None + :type histogram_threshold: float, optional :param initial_variation_api_fold: The number of variations to apply to the initial synthetic data, defaults to 0 :type initial_variation_api_fold: int, optional @@ -34,8 +34,9 @@ def __init__( :type next_variation_api_fold: int, optional :param keep_selected: Whether to keep the selected data in the next synthetic data, defaults to False :type keep_selected: bool, optional - :param selection_mode: The selection mode for selecting the data. It should be one of the following: "sample"( - random sampling proportional to the histogram). Defaults to "sample" + :param selection_mode: The selection mode for selecting the data. It should be one of the following: "sample" ( + random sampling proportional to the histogram), "rank" (select the top samples according to the histogram). + Defaults to "sample" :type selection_mode: str, optional :raises ValueError: If next_variation_api_fold is 0 and keep_selected is False """ @@ -88,8 +89,11 @@ def _post_process_histogram(self, syn_data): :rtype: :py:class:`pe.data.data.Data` """ count = syn_data.data_frame[DP_HISTOGRAM_COLUMN_NAME].to_numpy() - clipped_count = np.clip(count, a_min=self._histogram_threshold, a_max=None) - clipped_count -= self._histogram_threshold + if self._histogram_threshold is not None: + clipped_count = np.clip(count, a_min=self._histogram_threshold, a_max=None) + clipped_count -= self._histogram_threshold + else: + clipped_count = count syn_data.data_frame[POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME] = clipped_count return syn_data @@ -111,6 +115,12 @@ def _select_data(self, syn_data, num_samples): new_data_frame = syn_data.data_frame.iloc[indices] new_data_frame[PARENT_SYN_DATA_INDEX_COLUMN_NAME] = syn_data.data_frame.index[indices] return Data(data_frame=new_data_frame, metadata=syn_data.metadata) + elif self._selection_mode == "rank": + count = syn_data.data_frame[POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME].to_numpy() + indices = np.argsort(count)[::-1][:num_samples] + new_data_frame = syn_data.data_frame.iloc[indices] + new_data_frame[PARENT_SYN_DATA_INDEX_COLUMN_NAME] = syn_data.data_frame.index[indices] + return Data(data_frame=new_data_frame, metadata=syn_data.metadata) else: raise ValueError(f"Selection mode {self._selection_mode} is not supported") diff --git a/pe/runner/__init__.py b/pe/runner/__init__.py index 5d3b4ba..73496a8 100644 --- a/pe/runner/__init__.py +++ b/pe/runner/__init__.py @@ -1 +1,3 @@ from .pe import PE + +__all__ = ["PE"] diff --git a/pyproject.toml b/pyproject.toml index 93cebb6..eac8f57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,13 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["pre-commit", "black", "sphinx", "sphinx-rtd-theme"] +dev = [ + "pre-commit", + "black", + "sphinx", + "sphinx-rtd-theme", + "sphinx-toolbox", +] image = [ "blobfile", "torch", @@ -32,6 +38,20 @@ image = [ "improved-diffusion@git+https://github.com/fjxmlzn/improved-diffusion.git@8f6677c3c47d1c1ad2e22ad2603eaec4cc639805", "wilds", ] +text = [ + "gdown", + "openai", + "tenacity", + "azure-identity", + "tiktoken", + "python-dotenv", + "sentence-transformers", + "protobuf", + "sentencepiece", + "fschat", + "transformers", + "accelerate", +] [project.urls] Homepage = "https://microsoft.github.io/DPSDA/" @@ -39,5 +59,17 @@ Documentation = "https://microsoft.github.io/DPSDA/" Repository = "https://github.com/microsoft/DPSDA" "Bug Tracker" = "https://github.com/microsoft/DPSDA/issues" +[tool.setuptools] +include-package-data = false + 
[tool.setuptools.packages.find] -exclude = ["doc", "data", "example", "docker*", "amlt", "dist*", "_*", "result*", "*.egg-info", "__pycache__", ".git*"] +exclude = [ + "doc*", + "data*", + "example*", + "docker*", + "amlt*", + "dist*", + "_*", + "result*", +]
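The TextCSV loader added above derives an integer label ID for every row from the tuple of values in the label columns and stores the text under the internal PE.TEXT column. A minimal usage sketch, assuming a hypothetical CSV file "reviews.csv" with a "review" text column and a "sentiment" label column (all three names are made up for illustration):

from pe.data import TextCSV

# "reviews.csv", "review", and "sentiment" are placeholder names.
data = TextCSV(
    csv_path="reviews.csv",
    label_columns=["sentiment"],
    text_column="review",
    num_samples=1000,  # optional: only load the first 1000 rows
)
# The text now lives in the PE.TEXT column, and each distinct label tuple is mapped to an integer ID.
print(data.data_frame.columns)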
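The dataset classes (Yelp, PubMed, OpenReview) download and post-process their CSVs on first use, and the SentenceTransformer embedding caches its vectors in a model-specific column of the data frame, skipping rows that already have a value. A sketch combining the two; "stsb-roberta-base-v2" is only an example Sentence Transformers model name:

from pe.data import Yelp
from pe.embedding.text import SentenceTransformer

data = Yelp(root_dir="data/yelp", split="val")  # downloads and processes the validation split on first use
embedding = SentenceTransformer(model="stsb-roberta-base-v2")
data = embedding.compute_embedding(data)
# Embeddings are stored under embedding.column_name; calling compute_embedding again
# only processes rows that do not yet have a value in that column.
print(len(data.data_frame), data.data_frame[embedding.column_name].iloc[0].shape)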
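SaveTextToCSV writes one CSV per PE iteration, mapping each synthetic sample's label ID back to the original label column values through the dataset metadata. It is normally passed to the PE runner as one of the callbacks, but it can also be called directly on a synthetic Data object; a short sketch where the output folder is arbitrary and syn_data is assumed to come from a PE run:

from pe.callback.text import SaveTextToCSV

save_text_to_csv = SaveTextToCSV(output_folder="results/text", iteration_format="09d")
# syn_data: a pe.data.Data object produced by a PE run, carrying the text and label-ID
# columns and the iteration number in its metadata. For iteration 3 this writes
# results/text/000000003.csv.
save_text_to_csv(syn_data)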
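Request and the LLM wrappers separate the prompt from the generation arguments, and LLM.get_generation_args merges the constructor, call-site, and per-request arguments, with later sources overriding earlier ones. A sketch using OpenAILLM in dry-run mode, so no API call is actually made; note that the constructor still requires OPENAI_API_KEY to be set, and in a real run the target model would also be supplied through the generation arguments:

from pe.llm import OpenAILLM, Request

llm = OpenAILLM(dry_run=True, num_threads=4, temperature=1.0, max_completion_tokens=128)
requests = [
    Request(
        messages=[{"role": "user", "content": "Write a short, upbeat restaurant review."}],
        generation_args={"temperature": 1.4},  # overrides the constructor's temperature=1.0
    ),
    Request(messages=[{"role": "user", "content": "Write a one-star restaurant review."}]),
]
responses = llm.get_responses(requests)  # one string per request
print(responses[0])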
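AzureOpenAILLM is configured entirely through environment variables; as in the constructor above, the endpoint is read from AZURE_OPENAI_API_ENDPOINT. A sketch with placeholder values and dry run enabled so that no request is sent:

import os
from pe.llm import AzureOpenAILLM, Request

# Placeholder values -- replace with the settings of your Azure OpenAI resource.
os.environ["AZURE_OPENAI_API_KEY"] = "key-1,key-2"  # or "AZ_CLI" together with AZURE_OPENAI_API_SCOPE
os.environ["AZURE_OPENAI_API_ENDPOINT"] = "https://example-resource.openai.azure.com/"
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-02-01"  # example API version

llm = AzureOpenAILLM(dry_run=True, num_threads=2)
print(llm.get_responses([Request(messages=[{"role": "user", "content": "Hello"}])])[0])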
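HuggingfaceLLM renames max_completion_tokens to max_new_tokens, groups requests that share identical generation arguments, and generates each group in batches; importing pe.llm.huggingface registers a GPT-2 conversation template with FastChat. A minimal sketch, assuming the "text" extra is installed and a GPU is available (the model is loaded in float16), and using "gpt2" purely as an example checkpoint:

from pe.llm import HuggingfaceLLM, Request

llm = HuggingfaceLLM(
    model_name_or_path="gpt2",   # example; any Huggingface causal LM with a FastChat template should work
    batch_size=32,
    max_completion_tokens=64,    # mapped to max_new_tokens for this backend
    temperature=1.0,
)
requests = [Request(messages=[{"role": "user", "content": "Write a one-sentence movie review."}])]
print(llm.get_responses(requests)[0])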
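The PEPopulation changes above make histogram_threshold optional (None skips the clipping step entirely) and add a "rank" selection mode that keeps the candidates with the largest post-processed histogram counts instead of sampling proportionally. A configuration sketch, where the api argument stands for an already-constructed pe.api object:

from pe.population import PEPopulation

def build_populations(api):
    """api: an already-constructed pe.api API object (e.g., an LLM-backed text API)."""
    # Previous behavior: clip the DP histogram at histogram_threshold, then sample
    # proportionally to the clipped counts.
    population_sample = PEPopulation(api=api, histogram_threshold=2, selection_mode="sample")
    # New option: histogram_threshold=None leaves the counts untouched, and "rank"
    # deterministically keeps the top-counted candidates.
    population_rank = PEPopulation(api=api, histogram_threshold=None, selection_mode="rank")
    return population_sample, population_rank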