From 39d282f42b6ae8719745f3e3ff10c0471f1df220 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 16 Oct 2024 09:31:52 -0700 Subject: [PATCH] a new paper claims there is a free lunch by setting model weights to ema weights every epoch. allow researchers to experiment with this, conveniently already available in EMA-pytorch due to hare and tortoise paper --- README.md | 11 +++++++++++ alphafold3_pytorch/trainer.py | 2 ++ pyproject.toml | 4 ++-- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7fdde8bc..6dad2e58 100644 --- a/README.md +++ b/README.md @@ -483,3 +483,14 @@ docker run -v .:/data --gpus all -it af3 journal = {bioRxiv} } ``` + +```bibtex +@article{Li2024SwitchEA, + title = {Switch EMA: A Free Lunch for Better Flatness and Sharpness}, + author = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li}, + journal = {ArXiv}, + year = {2024}, + volume = {abs/2402.09240}, + url = {https://api.semanticscholar.org/CorpusID:267657558} +} +``` diff --git a/alphafold3_pytorch/trainer.py b/alphafold3_pytorch/trainer.py index 9da038d8..f469f43f 100644 --- a/alphafold3_pytorch/trainer.py +++ b/alphafold3_pytorch/trainer.py @@ -178,6 +178,7 @@ def __init__( use_foreach = True ), ema_on_cpu = False, + ema_update_model_with_ema_every: int | None = None, use_adam_atan2: bool = False, use_lion: bool = False, use_torch_compile: bool = False @@ -220,6 +221,7 @@ def __init__( include_online_model = False, allow_different_devices = True, coerce_dtype = True, + update_model_with_ema_every = ema_update_model_with_ema_every, **ema_kwargs ) diff --git a/pyproject.toml b/pyproject.toml index 8bea0da7..7853fee7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.6.2" +version = "0.6.3" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }, @@ -33,7 +33,7 @@ dependencies = [ "CoLT5-attention>=0.11.0", "einops>=0.8.0", "einx>=0.2.2", - "ema-pytorch>=0.6.4", + "ema-pytorch>=0.7.0", "environs", "lion-pytorch>=0.2.2", "joblib",