diff --git a/README.md b/README.md index 69c58aa9..624c1464 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,7 @@ git clone https://github.com/salesforce/LAVIS.git cd LAVIS pip install -e . ``` +If you are using an ARM CPU (for example, a Mac with Apple Silicon), please use `requirements_arm.txt` instead of `requirements.txt` to install the dependencies. ## Getting Started ### Model Zoo diff --git a/lavis/models/eva_vit.py b/lavis/models/eva_vit.py index 5b80b820..dc167bce 100644 --- a/lavis/models/eva_vit.py +++ b/lavis/models/eva_vit.py @@ -52,10 +52,18 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay self.drop = nn.Dropout(drop) def forward(self, x): + if self.fc1.weight.dtype == torch.float16: + x = x.half() + elif self.fc1.weight.dtype == torch.float32: + x = x.float() x = self.fc1(x) x = self.act(x) # x = self.drop(x) # commit this for the orignal BERT implement + if self.fc2.weight.dtype == torch.float16: + x = x.half() + elif self.fc2.weight.dtype == torch.float32: + x = x.float() x = self.fc2(x) x = self.drop(x) return x @@ -143,6 +151,10 @@ def forward(self, x, rel_pos_bias=None): attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + if self.proj.weight.dtype == torch.float16: + x = x.half() + elif self.proj.weight.dtype == torch.float32: + x = x.float() x = self.proj(x) x = self.proj_drop(x) return x @@ -200,6 +212,10 @@ def forward(self, x, **kwargs): # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ if self.proj.weight.dtype == torch.float16: + x = x.half() + elif self.proj.weight.dtype == torch.float32: + x = x.float() x = self.proj(x).flatten(2).transpose(1, 2) return x diff --git a/requirements_arm.txt b/requirements_arm.txt new file mode 100644 index 00000000..54adcb39 --- /dev/null +++ b/requirements_arm.txt @@ -0,0 +1,29 @@ +contexttimer +eva-decord +diffusers<=0.16.0 +einops>=0.4.1 +fairscale==0.4.4 +ftfy +iopath +ipython +omegaconf +opencv-python-headless==4.5.5.64 +opendatasets +packaging +pandas +plotly +pre-commit +pycocoevalcap +pycocotools +python-magic +scikit-image +sentencepiece +spacy +streamlit +timm==0.4.12 +torch>=1.10.0 +torchvision +tqdm +transformers>=4.28.0 +webdataset +wheel