Update LitGPT pin (#37)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
4 people authored Mar 22, 2024
1 parent 3ec2f74 commit 4437fa0
Showing 11 changed files with 67 additions and 68 deletions.
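For downstream code, the substance of this commit is a rename: the upstream package `lit_gpt` becomes `litgpt`, thunder's test shim `thunder.tests.lit_gpt_model` becomes `thunder.tests.litgpt_model`, and the private config fields `_norm_class`/`_mlp_class` become `norm_class_name`/`mlp_class_name`. A minimal before/after sketch (hypothetical user code, assuming the pinned litgpt commit is installed):

    # before this commit:
    # from lit_gpt import GPT
    # from thunder.tests.lit_gpt_model import Config

    # after this commit:
    from litgpt import GPT                         # upstream package is now "litgpt"
    from thunder.tests.litgpt_model import Config  # thunder's test shim renamed to match

    cfg = Config.from_name("Llama-2-7b-hf")  # config lookup is unchanged
    model = GPT(cfg)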
2 changes: 1 addition & 1 deletion notebooks/dev_tutorials/fsdp_tutorial.ipynb
@@ -1764,7 +1764,7 @@
 "%%writefile thunder_fsdp_simple_example.py\n",
 "\n",
 "# imports\n",
-"from thunder.tests.lit_gpt_model import GPT, Config\n",
+"from thunder.tests.litgpt_model import GPT, Config\n",
 "import torch\n",
 "import torch.distributed\n",
 "import thunder\n",
18 changes: 9 additions & 9 deletions notebooks/zero_to_thunder.ipynb
@@ -312,8 +312,8 @@
 }
 ],
 "source": [
-"from lit_gpt import GPT\n",
-"from thunder.tests.lit_gpt_model import Config\n",
+"from litgpt import GPT\n",
+"from thunder.tests.litgpt_model import Config\n",
 "cfg = Config.from_name('Llama-2-7b-hf')\n",
 "cfg.n_layer = 16 # fewer layers\n",
 "torch.set_default_dtype(torch.bfloat16)\n",
@@ -3326,7 +3326,7 @@
 ],
 "source": [
 "%%writefile zero_to_thunder_fsdp_simple_example.py\n",
-"from thunder.tests.lit_gpt_model import GPT, Config\n",
+"from thunder.tests.litgpt_model import GPT, Config\n",
 "import os\n",
 "import torch, torch.distributed\n",
 "import thunder, thunder.distributed\n",
@@ -3470,7 +3470,7 @@
 },
 "outputs": [],
 "source": [
-"import lit_gpt\n",
+"import litgpt\n",
 "def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n",
 " head_size = x.size(-1)\n",
 " x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n",
Expand All @@ -3493,7 +3493,7 @@
"\n",
"Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n",
"\n",
"In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `lit_gpt.model.apply_rope`.\n"
"In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `litgpt.model.apply_rope`.\n"
]
},
{
Expand All @@ -3504,17 +3504,17 @@
"outputs": [],
"source": [
"import torch, thunder\n",
"from thunder.tests.lit_gpt_model import GPT\n",
"from thunder.tests.litgpt_model import GPT\n",
"from thunder import TensorProxy\n",
"\n",
"def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n",
" return lit_gpt.model.apply_rope(x, cos, sin)\n",
" return litgpt.model.apply_rope(x, cos, sin)\n",
"\n",
"def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n",
" return TensorProxy(like=x)\n",
"\n",
"apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n",
" replaces=lit_gpt.model.apply_rope)"
" replaces=litgpt.model.apply_rope)"
]
},
{
@@ -3569,7 +3569,7 @@
 "with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n",
 "\n",
 "def test_apply_rope(x, m):\n",
-" return lit_gpt.model.apply_rope(x, m.cos, m.sin)\n",
+" return litgpt.model.apply_rope(x, m.cos, m.sin)\n",
 "\n",
 "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())\n",
 "\n",
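Note that the notebook cells above use an executor `my_ex` defined earlier in the notebook. For orientation, a self-contained sketch of that setup (assuming thunder's `OperatorExecutor` API; the name `my_ex` simply matches the cells above):

    import thunder
    from thunder.extend import OperatorExecutor, register_executor

    # create a custom operator executor and make thunder aware of it;
    # register_operator(...) calls like the one in the diff then attach
    # impl/meta symbol pairs to this executor
    my_ex = OperatorExecutor("my_ex", version="0.1")
    register_executor(my_ex)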
2 changes: 2 additions & 0 deletions requirements/notebooks.txt
@@ -1 +1,3 @@
 ipython[all] ==8.22.2
+
+litgpt @ git+https://github.com/Lightning-AI/lit-gpt@24d5eba1724c953b7506edc041a7da1ce226c129
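The distribution name in the pin is now `litgpt`, while the repository URL still says `lit-gpt`; pip keys the requirement on the distribution name, not the URL. A quick sanity check that the new pin exposes what the shim relies on (a sketch, assuming the requirement above is installed):

    # confirm the renamed package and the attributes thunder/tests/litgpt_model.py uses
    import litgpt

    print(hasattr(litgpt, "GPT"))                    # exposed at top level
    print(hasattr(litgpt.config, "name_to_config"))  # merged with the testing configs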
2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -11,7 +11,7 @@ expecttest ==0.2.1 # for test_ddp.py
 hypothesis ==6.99.10 # for test_ddp.py
 numpy # for test_ops.py
 einops # for test_einops.py
-lit_gpt @ git+https://github.com/Lightning-AI/lit-gpt@f241d94df59d82b2017bfdcd3800ac8779eb45f5
+litgpt @ git+https://github.com/Lightning-AI/lit-gpt@24d5eba1724c953b7506edc041a7da1ce226c129
 absl-py # thunder/benchmarks/test_benchmark_litgpt.py
 pandas # thunder/benchmarks/test_benchmark_litgpt.py
 xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py
28 changes: 13 additions & 15 deletions thunder/benchmarks/__init__.py
@@ -24,9 +24,9 @@
 import thunder.core.devices as Devices
 from thunder.core.transforms import grad, clear_grads, populate_grads
 import thunder.executors as executors
-from thunder.tests import nanogpt_model, hf_bart_self_attn, lit_gpt_model
+from thunder.tests import nanogpt_model, hf_bart_self_attn, litgpt_model
 from thunder.tests.make_tensor import make_tensor, make_tensor_like
-from thunder.tests.lit_gpt_model import Config as LitGPTConfig
+from thunder.tests.litgpt_model import Config as LitGPTConfig

 # List of all benchmarks
 benchmarks: list = []
@@ -1875,7 +1875,7 @@ class LlamaMLPBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
     _args = (
         BenchmarkArg(
             name="config",
-            description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.",
+            description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.",
         ),
         BenchmarkArg(
             name="batchdims",
@@ -1935,7 +1935,7 @@ def make_batch(self) -> tuple[list, dict]:

     def fn(self) -> Callable:
         module = (
-            lit_gpt_model.LLaMAMLP(self.config)
+            litgpt_model.LLaMAMLP(self.config)
             .to(device=self.device, dtype=self.tdtype)
             .requires_grad_(self.requires_grad)
         )
Expand All @@ -1946,7 +1946,7 @@ class LitGPTCausalSelfAttentionBenchmark(Benchmark, metaclass=UserFacingBenchmar
_args = (
BenchmarkArg(
name="config",
description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.",
description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.",
),
BenchmarkArg(
name="batchdims",
@@ -2005,7 +2005,7 @@ def make_batch(self) -> tuple[list, dict]:

     def fn(self) -> Callable:
         module = (
-            lit_gpt_model.CausalSelfAttention(self.config)
+            litgpt_model.CausalSelfAttention(self.config)
             .to(device=self.device, dtype=self.tdtype)
             .requires_grad_(self.requires_grad)
         )
@@ -2086,7 +2086,7 @@ def make_batch(self) -> tuple[list, dict]:

     def fn(self) -> Callable:
         module = (
-            lit_gpt_model.RMSNorm(self.size, self.dim, self.eps)
+            litgpt_model.RMSNorm(self.size, self.dim, self.eps)
             .to(device=self.device, dtype=self.tdtype)
             .requires_grad_(self.requires_grad)
         )
@@ -2168,7 +2168,7 @@ def make_batch(self) -> tuple[list, dict]:

     def fn(self) -> Callable:
         gpt = (
-            lit_gpt_model.GPT(self.config)
+            litgpt_model.GPT(self.config)
             .to(device=self.device, dtype=self.model_tdtype)
             .requires_grad_(self.requires_grad)
         )
@@ -2199,7 +2199,7 @@ def __init__(self, config, use_apex) -> None:

         super().__init__()
         self.config = config
-        self.apply_rope = lit_gpt_model.apply_rope
+        self.apply_rope = litgpt_model.apply_rope
         self.use_apex = use_apex

     def forward(
@@ -2254,7 +2254,7 @@ class LlamaQKVSplitRopeBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
     _args = (
         BenchmarkArg(
             name="config",
-            description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.",
+            description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.",
         ),
         BenchmarkArg(
             name="batchdims",
@@ -2610,7 +2610,7 @@ def __init__(
         dtype: dtypes.dtype = thunder.bfloat16,
         requires_grad: bool = True,
     ) -> None:
-        from thunder.tests.lit_gpt_model import Config
+        from thunder.tests.litgpt_model import Config

         litgptconfig = Config.from_name(config) if not isinstance(config, Config) else config
         nanogptconfig = NanoGPTConfig(
@@ -2793,7 +2793,7 @@ def __init__(
         # Sets required benchmark parameters
         self.devices: list[str] = [device]

-        self.cos, self.sin = lit_gpt_model.build_rope_cache(
+        self.cos, self.sin = litgpt_model.build_rope_cache(
             seq_len=seq_length, n_elem=self.config.rope_n_elem, device=self.device
         )

Expand All @@ -2806,9 +2806,7 @@ def make_batch(self) -> tuple[list, dict]:

def fn(self) -> Callable:
model = (
lit_gpt_model.Block(self.config)
.to(device=self.device, dtype=self.tdtype)
.requires_grad_(self.requires_grad)
litgpt_model.Block(self.config).to(device=self.device, dtype=self.tdtype).requires_grad_(self.requires_grad)
)
return model

Expand Down
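For reference, constructing one of these benchmarks after the rename mirrors what test_interpreter.py (below) exercises; a sketch, assuming the `Benchmark` interface shown in the hunks above (`make_batch` returning args/kwargs and `fn` returning a callable module):

    import torch
    from thunder.benchmarks import LitGPTBenchmark
    from thunder.tests.litgpt_model import Config

    # "gpt-neox-like" is one of the registered testing configurations
    cfg = Config.from_name("gpt-neox-like")
    bench = LitGPTBenchmark(config=cfg, device="cpu", dtype=torch.bfloat16, requires_grad=True)

    args, kwargs = bench.make_batch()
    out = bench.fn()(*args, **kwargs)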
2 changes: 1 addition & 1 deletion thunder/benchmarks/benchmark_litgpt.py
@@ -7,7 +7,7 @@
 import torch.distributed as torch_dist

 import thunder
-from thunder.tests.lit_gpt_model import Config, GPT, Block
+from thunder.tests.litgpt_model import Config, GPT, Block

 from lightning.fabric.utilities.throughput import measure_flops
 from lightning.fabric.utilities import Throughput
4 changes: 2 additions & 2 deletions thunder/benchmarks/distributed.py
@@ -14,7 +14,7 @@
     LitGPTBenchmark,
     LitGPTConfig,
 )
-from thunder.tests.lit_gpt_model import name_to_config
+from thunder.tests.litgpt_model import name_to_config
 from thunder.distributed import FSDPBucketingStrategy
 from thunder.distributed import FSDPType

@@ -299,7 +299,7 @@ def parse_args() -> argparse.Namespace:
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from thunder.benchmarks import get_default_torch_fsdp_executor
    from thunder.tests.nanogpt_model import Block as NanoGPTBlock
-   from thunder.tests.lit_gpt_model import Block as GPTBlock
+   from thunder.tests.litgpt_model import Block as GPTBlock

    sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
    auto_wrap_policies = (
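The `Block as GPTBlock` import above feeds torch FSDP's transformer auto-wrap policy; a sketch of the usual combination (standard PyTorch FSDP API; the wrapping choice is illustrative):

    import functools
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from thunder.tests.litgpt_model import Block as GPTBlock

    # wrap each transformer block as its own FSDP unit
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={GPTBlock},
    )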
2 changes: 1 addition & 1 deletion thunder/benchmarks/targets.py
@@ -39,7 +39,7 @@
     thunder_sdpa_torch_compile_nvfuser_executor,
 )

-from thunder.tests.lit_gpt_model import Config as LitGPTConfig
+from thunder.tests.litgpt_model import Config as LitGPTConfig


 APEX_FUSED_ROPE_AVAILABLE: bool = package_available("fused_rotary_positional_embedding")
47 changes: 23 additions & 24 deletions thunder/tests/lit_gpt_model.py → thunder/tests/litgpt_model.py
@@ -1,4 +1,4 @@
-"""Taken from https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/model.py"""
+"""Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py"""
 import torch
 import torch.nn as nn

Expand All @@ -18,9 +18,9 @@
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_class_name="RMSNorm",
norm_eps=1e-6,
_mlp_class="LLaMAMLP",
mlp_class_name="LLaMAMLP",
intermediate_size=1376,
),
dict(
Expand All @@ -34,8 +34,8 @@
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
_mlp_class="LLaMAMLP",
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=11008,
rope_condense_ratio=4,
),
Expand All @@ -49,8 +49,8 @@
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
_mlp_class="LLaMAMLP",
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=1376,
),
dict(
@@ -87,9 +87,9 @@
         rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
-        _norm_class="RMSNorm",
+        norm_class_name="RMSNorm",
         norm_eps=1e-05,
-        _mlp_class="LLaMAMLP",
+        mlp_class_name="LLaMAMLP",
         intermediate_size=1376,
         rope_base=1000000,
     ),
Expand All @@ -104,9 +104,9 @@
n_query_groups=8,
parallel_residual=False,
bias=False,
_norm_class="RMSNorm",
norm_class_name="RMSNorm",
norm_eps=1e-05,
_mlp_class="LLaMAMoE",
mlp_class_name="LLaMAMoE",
intermediate_size=224,
rope_base=1000000,
n_expert=8,
@@ -150,21 +150,20 @@ def reset_parameters(self) -> None:
         torch.nn.init.zeros_(self.v)


-import lit_gpt
-import lit_gpt.rmsnorm
+import litgpt

 # override for operator workarounds
-lit_gpt.model.KVCache = OverridenKVCache
+litgpt.model.KVCache = OverridenKVCache
 # add the testing configurations
-lit_gpt.config.name_to_config.update(name_to_config)
-name_to_config.update(lit_gpt.config.name_to_config)
+litgpt.config.name_to_config.update(name_to_config)
+name_to_config.update(litgpt.config.name_to_config)

 # manually expose for backwards compatibility
-Config = lit_gpt.Config
-GPT = lit_gpt.GPT
-RMSNorm = lit_gpt.rmsnorm.RMSNorm
-CausalSelfAttention = lit_gpt.model.CausalSelfAttention
-LLaMAMLP = lit_gpt.model.LLaMAMLP
-build_rope_cache = lit_gpt.model.build_rope_cache
-apply_rope = lit_gpt.model.apply_rope
-Block = lit_gpt.model.Block
+Config = litgpt.Config
+GPT = litgpt.GPT
+RMSNorm = litgpt.model.RMSNorm
+CausalSelfAttention = litgpt.model.CausalSelfAttention
+LLaMAMLP = litgpt.model.LLaMAMLP
+build_rope_cache = litgpt.model.build_rope_cache
+apply_rope = litgpt.model.apply_rope
+Block = litgpt.model.Block
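The module above is a compatibility shim: it patches `litgpt.model.KVCache`, merges the testing configs into litgpt's registry, and re-exports the litgpt names so existing imports keep working. A minimal usage sketch (assuming the pinned litgpt commit is installed):

    # old call sites only need the module rename; the re-exported names are unchanged
    from thunder.tests.litgpt_model import GPT, Config

    cfg = Config.from_name("Llama-2-7b-hf")  # registry now includes litgpt's configs too
    cfg.n_layer = 2                          # shrink for a quick smoke test
    model = GPT(cfg)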
14 changes: 7 additions & 7 deletions thunder/tests/test_interpreter.py
@@ -1337,26 +1337,26 @@ def foo(a):

     def foo():
         # test relative import
-        from .lit_gpt_model import Config
+        from .litgpt_model import Config

         return Config

     assert jit(foo)() is foo()

     def foo():
         # test relative import
-        from . import lit_gpt_model
+        from . import litgpt_model

-        return lit_gpt_model.Config
+        return litgpt_model.Config

     assert jit(foo)() is foo()

     # reload is implemented using exec of the module
-    from . import lit_gpt_model
+    from . import litgpt_model
     import importlib

-    importlib.reload(lit_gpt_model)
-    assert hasattr(lit_gpt_model, "GPT")
+    importlib.reload(litgpt_model)
+    assert hasattr(litgpt_model, "GPT")
@@ -3071,7 +3071,7 @@ def test_nanogpt(jit):

 def test_litgpt(jit):
     from thunder.benchmarks import LitGPTBenchmark
-    from thunder.tests.lit_gpt_model import Config
+    from thunder.tests.litgpt_model import Config

     cfg: Config = Config.from_name("gpt-neox-like")
     bench = LitGPTBenchmark(config=cfg, device="cpu", dtype=torch.bfloat16, requires_grad=True)