Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class ConditionalFlowMatcher:
- Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
- conditional flow matching ut(x1|x0) = x1 - x0
- score function $\nabla log p_t(x|x0, x1)$

[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""

def __init__(self, sigma: Union[float, int] = 0.0):
Expand Down Expand Up @@ -76,7 +78,7 @@ def compute_mu_t(self, x0, x1, t):

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
t = pad_t_like_x(t, x0)
return t * x1 + (1 - t) * x0
Expand All @@ -95,7 +97,7 @@ def compute_sigma_t(self, t):

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
del t
return self.sigma
Expand All @@ -120,7 +122,7 @@ def sample_xt(self, x0, x1, t, epsilon):

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
mu_t = self.compute_mu_t(x0, x1, t)
sigma_t = self.compute_sigma_t(t)
Expand All @@ -147,7 +149,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
del t, xt
return x1 - x0
Expand Down Expand Up @@ -183,7 +185,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
if t is None:
t = torch.rand(x0.shape[0]).type_as(x0)
Expand All @@ -210,7 +212,7 @@ def compute_lambda(self, t):

References
----------
[4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
[4] Tong et al., Simulation-free Schrodinger bridges via score and flow matching, https://arxiv.org/abs/2307.03672
"""
sigma_t = self.compute_sigma_t(t)
return 2 * sigma_t / (self.sigma**2 + 1e-8)
Expand All @@ -221,6 +223,8 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):
the OT-CFM methods from [1] and inherits the ConditionalFlowMatcher parent class.

It overrides the sample_location_and_conditional_flow.

[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""

def __init__(self, sigma: Union[float, int] = 0.0):
Expand Down Expand Up @@ -262,7 +266,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
Expand Down Expand Up @@ -301,7 +305,7 @@ def guided_sample_location_and_conditional_flow(

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
Expand All @@ -317,7 +321,7 @@ class TargetConditionalFlowMatcher(ConditionalFlowMatcher):
ConditionalFlowMatcher and override the compute_mu_t, compute_sigma_t and
compute_conditional_flow functions in order to compute [2]'s flow matching.

[2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
[2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747
"""

def compute_mu_t(self, x0, x1, t):
Expand All @@ -337,7 +341,7 @@ def compute_mu_t(self, x0, x1, t):

References
----------
[2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
[2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747
"""
del x0
t = pad_t_like_x(t, x1)
Expand All @@ -357,7 +361,7 @@ def compute_sigma_t(self, t):

References
----------
[2] Flow Matching for Generative Modelling, ICLR, Lipman et al.
[2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747
"""
return 1 - (1 - self.sigma) * t

Expand All @@ -381,7 +385,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):

References
----------
[1] Flow Matching for Generative Modelling, ICLR, Lipman et al.
[2] Lipman et al., Flow Matching for Generative Modelling, ICLR, https://arxiv.org/abs/2210.02747
"""
del x0
t = pad_t_like_x(t, x1)
Expand All @@ -394,6 +398,8 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):

It overrides the compute_sigma_t, compute_conditional_flow and
sample_location_and_conditional_flow functions.

[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""

def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):
Expand Down Expand Up @@ -432,7 +438,7 @@ def compute_sigma_t(self, t):

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
return self.sigma * torch.sqrt(t * (1 - t))

Expand All @@ -459,8 +465,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):

References
----------
[1] Improving and Generalizing Flow-Based Generative Models
with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
t = pad_t_like_x(t, x0)
mu_t = self.compute_mu_t(x0, x1, t)
Expand Down Expand Up @@ -497,7 +502,7 @@ def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=Fals

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
return super().sample_location_and_conditional_flow(x0, x1, t, return_noise)
Expand Down Expand Up @@ -536,7 +541,7 @@ def guided_sample_location_and_conditional_flow(

References
----------
[1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
[1] Tong et al., Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint https://arxiv.org/abs/2302.00482
"""
x0, x1, y0, y1 = self.ot_sampler.sample_plan_with_labels(x0, x1, y0, y1)
if return_noise:
Expand All @@ -552,7 +557,7 @@ class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in
order to compute [3]'s trigonometric interpolants.

[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
[3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797
"""

def compute_mu_t(self, x0, x1, t):
Expand All @@ -572,7 +577,7 @@ def compute_mu_t(self, x0, x1, t):

References
----------
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
[3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797
"""
t = pad_t_like_x(t, x0)
return torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1
Expand Down Expand Up @@ -600,7 +605,7 @@ def compute_conditional_flow(self, x0, x1, t, xt):

References
----------
[3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
[3] Albergo et al., Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, https://arxiv.org/abs/2303.08797
"""
del xt
t = pad_t_like_x(t, x0)
Expand Down