From 4579d5e90b3decb9aa97dd78d6b5cf453f50e5f9 Mon Sep 17 00:00:00 2001 From: Peter Steinbach Date: Fri, 15 Aug 2025 13:50:23 +0200 Subject: [PATCH] updated references to include URLs --- torchcfm/conditional_flow_matching.py | 45 +++++++++++++++------------ 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index b71b6bd2..e34a2613 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -47,6 +47,8 @@ class ConditionalFlowMatcher: - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function - conditional flow matching ut(x1|x0) = x1 - x0 - score function $\nabla log p_t(x|x0, x1)$ + + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ def __init__(self, sigma: Union[float, int] = 0.0): @@ -76,7 +78,7 @@ def compute_mu_t(self, x0, x1, t): References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ t = pad_t_like_x(t, x0) return t * x1 + (1 - t) * x0 @@ -95,7 +97,7 @@ def compute_sigma_t(self, t): References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ del t return self.sigma @@ -120,7 +122,7 @@ def sample_xt(self, x0, x1, t, epsilon): References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ mu_t = self.compute_mu_t(x0, x1, t) sigma_t = self.compute_sigma_t(t) @@ -147,7 +149,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ del t, xt return x1 - x0 @@ -183,7 +185,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ if t is None: t = torch.rand(x0.shape[0]).type_as(x0) @@ -210,7 +212,7 @@ def compute_lambda(self, t): References ---------- - [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al. + [4] Tong et al., Simulation-free Schrodinger bridges via score and flow matching, https://arxiv.org/abs/2307.03672 """ sigma_t = self.compute_sigma_t(t) return 2 * sigma_t / (self.sigma**2 + 1e-8) @@ -221,6 +223,8 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher): the OT-CFM methods from [1] and inherits the ConditionalFlowMatcher parent class. It overrides the sample_location_and_conditional_flow. + + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ def __init__(self, sigma: Union[float, int] = 0.0): @@ -262,7 +266,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ x0, x1 = self.ot_sampler.sample_plan(x0, x1) return super().sample_location_and_conditional_flow(x0, x1, t, return_noise) @@ -301,7 +305,7 @@ def guided_sample_location_and_conditional_flow( References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1) if return_noise: @@ -317,7 +321,7 @@ class TargetConditionalFlowMatcher(ConditionalFlowMatcher): ConditionalFlowMatcher and override the compute_mu_t, compute_sigma_t and compute_conditional_flow functions in order to compute [2]'s flow matching. - [2] Flow Matching for Generative Modelling, ICLR, Lipman et al. + [2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747 """ def compute_mu_t(self, x0, x1, t): @@ -337,7 +341,7 @@ def compute_mu_t(self, x0, x1, t): References ---------- - [2] Flow Matching for Generative Modelling, ICLR, Lipman et al. + [2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747 """ del x0 t = pad_t_like_x(t, x1) @@ -357,7 +361,7 @@ def compute_sigma_t(self, t): References ---------- - [2] Flow Matching for Generative Modelling, ICLR, Lipman et al. + [2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747 """ return 1 - (1 - self.sigma) * t @@ -381,7 +385,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): References ---------- - [1] Flow Matching for Generative Modelling, ICLR, Lipman et al. + [2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747 """ del x0 t = pad_t_like_x(t, x1) @@ -394,6 +398,8 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher): It overrides the compute_sigma_t, compute_conditional_flow and sample_location_and_conditional_flow functions. + + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"): @@ -432,7 +438,7 @@ def compute_sigma_t(self, t): References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ return self.sigma * torch.sqrt(t * (1 - t)) @@ -459,8 +465,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): References ---------- - [1] Improving and Generalizing Flow-Based Generative Models - with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ t = pad_t_like_x(t, x0) mu_t = self.compute_mu_t(x0, x1, t) @@ -497,7 +502,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ x0, x1 = self.ot_sampler.sample_plan(x0, x1) return super().sample_location_and_conditional_flow(x0, x1, t, return_noise) @@ -536,7 +541,7 @@ def guided_sample_location_and_conditional_flow( References ---------- - [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al. + [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482 """ x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1) if return_noise: @@ -552,7 +557,7 @@ class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher): ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in order to compute [3]'s trigonometric interpolants. - [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al. + [3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797 """ def compute_mu_t(self, x0, x1, t): @@ -572,7 +577,7 @@ def compute_mu_t(self, x0, x1, t): References ---------- - [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al. + [3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797 """ t = pad_t_like_x(t, x0) return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1 @@ -600,7 +605,7 @@ def compute_conditional_flow(self, x0, x1, t, xt): References ---------- - [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al. + [3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797 """ del xt t = pad_t_like_x(t, x0)