switch to jaxtyping
lucidrains committed Aug 5, 2024
1 parent c1a1788 commit 88c7f23
Showing 3 changed files with 35 additions and 10 deletions.
setup.py (7 changes: 3 additions & 4 deletions)
@@ -3,7 +3,7 @@
 setup(
   name = 'taylor-series-linear-attention',
   packages = find_packages(exclude=[]),
-  version = '0.1.9',
+  version = '0.1.11',
   license='MIT',
   description = 'Taylor Series Linear Attention',
   author = 'Phil Wang',
@@ -17,10 +17,9 @@
   ],
   install_requires=[
     'einops>=0.7.0',
-    'einx',
+    'jaxtyping',
     'rotary-embedding-torch>=0.5.3',
-    'torch>=2.0',
-    'torchtyping'
+    'torch>=2.0'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
taylor_series_linear_attention/attention.py (12 changes: 6 additions & 6 deletions)
@@ -1,3 +1,4 @@
+from __future__ import annotations
 import importlib
 from functools import partial
 from collections import namedtuple
@@ -10,8 +11,7 @@
 from einops import rearrange, pack, unpack
 from einops.layers.torch import Rearrange

-from typing import Optional
-from torchtyping import TensorType
+from taylor_series_linear_attention.tensor_typing import Float, Int, Bool

 from rotary_embedding_torch import RotaryEmbedding

@@ -145,11 +145,11 @@ def __init__(

     def forward(
         self,
-        x: TensorType['batch', 'seq', 'dim', float],
-        mask: Optional[TensorType['batch', 'seq', bool]] = None,
-        context: Optional[TensorType['batch', 'target_seq', 'dim', float]] = None,
+        x: Float['batch seq dim'],
+        mask: Bool['batch seq'] | None = None,
+        context: Float['batch target_seq dim'] | None = None,
         eps: float = 1e-5,
-        cache: Optional[Cache] = None,
+        cache: Cache | None = None,
         return_cache = False
     ):
         """
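For reference, a minimal sketch of how the re-annotated forward is called. The constructor arguments below are illustrative assumptions (mirroring the project README) and are not part of this commit; the comments restate the new jaxtyping-style shape annotations.

import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn

# illustrative hyperparameters, not taken from this diff
attn = TaylorSeriesLinearAttn(dim = 512, dim_head = 16, heads = 16)

x    = torch.randn(2, 2048, 512)      # Float['batch seq dim']
mask = torch.ones(2, 2048).bool()     # Bool['batch seq']

out = attn(x, mask = mask)            # Float['batch seq dim'], same shape as x
assert out.shape == x.shape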
taylor_series_linear_attention/tensor_typing.py (26 changes: 26 additions & 0 deletions)
@@ -0,0 +1,26 @@
+from torch import Tensor
+
+from jaxtyping import (
+    Float,
+    Int,
+    Bool
+)
+
+# jaxtyping is a misnomer, works for pytorch
+
+class TorchTyping:
+    def __init__(self, abstract_dtype):
+        self.abstract_dtype = abstract_dtype
+
+    def __getitem__(self, shapes: str):
+        return self.abstract_dtype[Tensor, shapes]
+
+Float = TorchTyping(Float)
+Int = TorchTyping(Int)
+Bool = TorchTyping(Bool)
+
+__all__ = [
+    Float,
+    Int,
+    Bool
+]
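A short sketch of the shorthand this new module provides: the wrapped Float / Int / Bool accept just a shape string and fill in torch.Tensor, so Float['batch seq dim'] is equivalent to jaxtyping's Float[Tensor, 'batch seq dim']. The scale functions below are hypothetical examples, not code from the repository.

from torch import Tensor
from jaxtyping import Float as JaxFloat

from taylor_series_linear_attention.tensor_typing import Float

# shorthand enabled by the TorchTyping wrapper
def scale(x: Float['batch seq dim']) -> Float['batch seq dim']:
    return x * 2.

# the equivalent annotation written directly with jaxtyping
def scale_verbose(x: JaxFloat[Tensor, 'batch seq dim']) -> JaxFloat[Tensor, 'batch seq dim']:
    return x * 2.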
