Skip to content

Commit

Permalink
NoOp Model (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
Landanjs authored Jun 20, 2024
1 parent 93a5469 commit ae4920f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
2 changes: 2 additions & 0 deletions diffusion/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl)
from diffusion.models.noop import NoOpModel
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.stable_diffusion import StableDiffusion

Expand All @@ -13,6 +14,7 @@
'build_diffusers_autoencoder',
'continuous_pixel_diffusion',
'discrete_pixel_diffusion',
'NoOpModel',
'PixelDiffusion',
'stable_diffusion_2',
'stable_diffusion_xl',
Expand Down
49 changes: 49 additions & 0 deletions diffusion/models/noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""NoOpModel algorithm and class."""

from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from composer.models.base import ComposerModel
from torchmetrics import Metric

from diffusion.models.text_encoder import MultiTokenizer


class NoOpModel(ComposerModel):
"""No-op model used to measure dataloader throughput.
Args:
tokenizer_names (str, Tuple[str, ...]): HuggingFace name(s) of the tokenizer(s) to load.
Default: ``('stabilityai/stable-diffusion-xl-base-1.0/tokenizer',
'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2')``.
"""

def __init__(
self,
tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer',
'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'),
):
super().__init__()
self.weight = torch.nn.Linear(in_features=1, out_features=16)
self.tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)

def loss(self, outputs: torch.Tensor, batch):
y = torch.randn_like(self.weight.weight)
return F.mse_loss(outputs, y)

def forward(self, batch):
input = torch.randn_like(self.weight.weight).sum().unsqueeze(0)
return self.weight(input)

def get_metrics(self, is_train: bool) -> Dict[str, Metric]:
return {}

def eval_forward(self, batch, outputs: Optional[Any] = None):
return self.forward(batch)

def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
pass

0 comments on commit ae4920f

Please sign in to comment.