From 87e57695d08ce1044af1f15d0690ee887a957c75 Mon Sep 17 00:00:00 2001 From: Yash Shah Date: Tue, 3 Oct 2023 13:08:05 -0700 Subject: [PATCH] fix pylint bug --- MaxText/layers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/MaxText/layers.py b/MaxText/layers.py index 1bdf28d8e..3bc6cc58d 100644 --- a/MaxText/layers.py +++ b/MaxText/layers.py @@ -25,7 +25,6 @@ from jax.sharding import Mesh from jax.sharding import PartitionSpec as P -import dataclasses import functools import operator from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union @@ -656,6 +655,7 @@ class Embed(nn.Module): dtype: the dtype of the embedding vectors (default: float32). embedding_init: embedding initializer. """ + # pylint: disable=attribute-defined-outside-init config: Config num_embeddings: int features: int @@ -663,7 +663,6 @@ class Embed(nn.Module): dtype: DType = jnp.float32 attend_dtype: Optional[DType] = None embedding_init: Initializer = default_embed_init - embedding: Array = dataclasses.field(init=False) def setup(self): self.embedding = self.param(