From 4579d5e90b3decb9aa97dd78d6b5cf453f50e5f9 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Fri, 15 Aug 2025 13:50:23 +0200
Subject: [PATCH] updated references to include URLs
---
torchcfm/conditional_flow_matching.py | 45 +++++++++++++++------------
1 file changed, 25 insertions(+), 20 deletions(-)
diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py
index b71b6bd2..e34a2613 100644
--- a/torchcfm/conditional_flow_matching.py
+++ b/torchcfm/conditional_flow_matching.py
@@ -47,6 +47,8 @@ class ConditionalFlowMatcher:
- Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
- conditional flow matching ut(x1|x0) = x1 - x0
- score function $\nabla log p_t(x|x0, x1)$
+
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
def __init__(self, sigma: Union[float, int] = 0.0):
@@ -76,7 +78,7 @@ def compute_mu_t(self, x0, x1, t):
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
t = pad_t_like_x(t, x0)
return t * x1 + (1 - t) * x0
@@ -95,7 +97,7 @@ def compute_sigma_t(self, t):
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
del t
return self.sigma
@@ -120,7 +122,7 @@ def sample_xt(self, x0, x1, t, epsilon):
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
mu_t = self.compute_mu_t(x0, x1, t)
sigma_t = self.compute_sigma_t(t)
@@ -147,7 +149,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
del t, xt
return x1 - x0
@@ -183,7 +185,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
if t is None:
t = torch.rand(x0.shape[0]).type_as(x0)
@@ -210,7 +212,7 @@ def compute_lambda(self, t):
References
----------
- [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
+ [4] Tong et al., Simulation-free Schrodinger bridges via score and flow matching, https://arxiv.org/abs/2307.03672
"""
sigma_t = self.compute_sigma_t(t)
return 2 * sigma_t / (self.sigma**2 + 1e-8)
@@ -221,6 +223,8 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):
the OT-CFM methods from [1] and inherits the ConditionalFlowMatcher parent class.
It overrides the sample_location_and_conditional_flow.
+
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
def __init__(self, sigma: Union[float, int] = 0.0):
@@ -262,7 +266,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
@@ -301,7 +305,7 @@ def guided_sample_location_and_conditional_flow(
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
@@ -317,7 +321,7 @@ class TargetConditionalFlowMatcher(ConditionalFlowMatcher):
ConditionalFlowMatcher and override the compute_mu_t, compute_sigma_t and
compute_conditional_flow functions in order to compute [2]'s flow matching.
- [2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
+ [2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747
"""
def compute_mu_t(self, x0, x1, t):
@@ -337,7 +341,7 @@ def compute_mu_t(self, x0, x1, t):
References
----------
- [2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
+ [2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747
"""
del x0
t = pad_t_like_x(t, x1)
@@ -357,7 +361,7 @@ def compute_sigma_t(self, t):
References
----------
- [2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
+ [2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747
"""
return 1 - (1 - self.sigma) * t
@@ -381,7 +385,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
References
----------
- [1] Flow Matching for Generative Modelling, ICLR, Lipman et al.
+ [2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747
"""
del x0
t = pad_t_like_x(t, x1)
@@ -394,6 +398,8 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):
It overrides the compute_sigma_t, compute_conditional_flow and
sample_location_and_conditional_flow functions.
+
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):
@@ -432,7 +438,7 @@ def compute_sigma_t(self, t):
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
return self.sigma * torch.sqrt(t * (1 - t))
@@ -459,8 +465,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models
- with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
t = pad_t_like_x(t, x0)
mu_t = self.compute_mu_t(x0, x1, t)
@@ -497,7 +502,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
@@ -536,7 +541,7 @@ def guided_sample_location_and_conditional_flow(
References
----------
- [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
+ [1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
@@ -552,7 +557,7 @@ class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in
order to compute [3]'s trigonometric interpolants.
- [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
+ [3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797
"""
def compute_mu_t(self, x0, x1, t):
@@ -572,7 +577,7 @@ def compute_mu_t(self, x0, x1, t):
References
----------
- [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
+ [3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797
"""
t = pad_t_like_x(t, x0)
return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1
@@ -600,7 +605,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):
References
----------
- [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
+ [3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797
"""
del xt
t = pad_t_like_x(t, x0)