diff --git a/parallelformers/policies/base/auto.py b/parallelformers/policies/base/auto.py
index 433e4c3..98f3c3a 100644
--- a/parallelformers/policies/base/auto.py
+++ b/parallelformers/policies/base/auto.py
@@ -12,305 +12,406 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import suppress
 from typing import List, Union
 
 from torch import nn
-from transformers.models.albert.modeling_albert import AlbertPreTrainedModel
-from transformers.models.bart.modeling_bart import BartPretrainedModel
-from transformers.models.bert.modeling_bert import BertPreTrainedModel
-from transformers.models.bert_generation.modeling_bert_generation import (
-    BertGenerationPreTrainedModel,
-)
-from transformers.models.big_bird.modeling_big_bird import (
-    BigBirdPreTrainedModel,
-)
-from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import (
-    BigBirdPegasusPreTrainedModel,
-)
-from transformers.models.blenderbot.modeling_blenderbot import (
-    BlenderbotPreTrainedModel,
-)
-from transformers.models.blenderbot_small.modeling_blenderbot_small import (
-    BlenderbotSmallPreTrainedModel,
-)
-from transformers.models.clip.modeling_clip import CLIPPreTrainedModel
-from transformers.models.convbert.modeling_convbert import (
-    ConvBertPreTrainedModel,
-)
-from transformers.models.ctrl.modeling_ctrl import CTRLPreTrainedModel
-from transformers.models.deberta.modeling_deberta import DebertaPreTrainedModel
-from transformers.models.deberta_v2.modeling_deberta_v2 import (
-    DebertaV2PreTrainedModel,
-)
-from transformers.models.deit.modeling_deit import DeiTPreTrainedModel
-from transformers.models.detr.modeling_detr import DetrPreTrainedModel
-from transformers.models.distilbert.modeling_distilbert import (
-    DistilBertPreTrainedModel,
-)
-from transformers.models.dpr.modeling_dpr import (
-    DPRPretrainedContextEncoder,
-    DPRPretrainedQuestionEncoder,
-    DPRPretrainedReader,
-)
-from transformers.models.electra.modeling_electra import ElectraPreTrainedModel
-from transformers.models.fsmt.modeling_fsmt import PretrainedFSMTModel
-from transformers.models.funnel.modeling_funnel import FunnelPreTrainedModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel
-from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoPreTrainedModel
-from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel
-from transformers.models.ibert.modeling_ibert import IBertPreTrainedModel
-from transformers.models.layoutlm.modeling_layoutlm import (
-    LayoutLMPreTrainedModel,
-)
-from transformers.models.led.modeling_led import LEDPreTrainedModel
-from transformers.models.longformer.modeling_longformer import (
-    LongformerPreTrainedModel,
-)
-from transformers.models.luke.modeling_luke import LukePreTrainedModel
-from transformers.models.lxmert.modeling_lxmert import LxmertPreTrainedModel
-from transformers.models.m2m_100.modeling_m2m_100 import M2M100PreTrainedModel
-from transformers.models.marian.modeling_marian import MarianPreTrainedModel
-from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel
-from transformers.models.mobilebert.modeling_mobilebert import (
-    MobileBertPreTrainedModel,
-)
-from transformers.models.mpnet.modeling_mpnet import MPNetPreTrainedModel
-from transformers.models.openai.modeling_openai import OpenAIGPTPreTrainedModel
-from transformers.models.pegasus.modeling_pegasus import PegasusPreTrainedModel
-from transformers.models.prophetnet.modeling_prophetnet import (
-    ProphetNetPreTrainedModel,
-)
-from transformers.models.reformer.modeling_reformer import (
-    ReformerPreTrainedModel,
-)
-from transformers.models.retribert.modeling_retribert import (
-    RetriBertPreTrainedModel,
-)
-from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
-from transformers.models.roformer.modeling_roformer import (
-    RoFormerPreTrainedModel,
-)
-from transformers.models.speech_to_text.modeling_speech_to_text import (
-    Speech2TextPreTrainedModel,
-)
-from transformers.models.t5.modeling_t5 import T5PreTrainedModel
-from transformers.models.tapas.modeling_tapas import TapasPreTrainedModel
-from transformers.models.transfo_xl.modeling_transfo_xl import (
-    TransfoXLPreTrainedModel,
-)
-from transformers.models.visual_bert.modeling_visual_bert import (
-    VisualBertPreTrainedModel,
-)
-from transformers.models.vit.modeling_vit import ViTPreTrainedModel
-from transformers.models.wav2vec2.modeling_wav2vec2 import (
-    Wav2Vec2PreTrainedModel,
-)
-from transformers.models.xlm.modeling_xlm import XLMPreTrainedModel
-from transformers.models.xlnet.modeling_xlnet import XLNetPreTrainedModel
-
-from parallelformers.policies.albert import AlbertPolicy
-from parallelformers.policies.bart import BartDecoderPolicy, BartEncoderPolicy
+
 from parallelformers.policies.base import Policy
-from parallelformers.policies.bert import BertPolicy
-from parallelformers.policies.bigbird import BigBirdPolicy
-from parallelformers.policies.bigbird_pegasus import (
-    BigBirdPegasusDecoderPolicy,
-    BigBirdPegasusEncoderPolicy,
-)
-from parallelformers.policies.blenderbot import (
-    BlenderbotDecoderPolicy,
-    BlenderbotEncoderPolicy,
-)
-from parallelformers.policies.blenderbot_small import (
-    BlenderbotSmallDecoderPolicy,
-    BlenderbotSmallEncoderPolicy,
-)
-from parallelformers.policies.clip import (
-    CLIPLayerPolicy,
-    CLIPTextPolicy,
-    CLIPVisionPolicy,
-)
-from parallelformers.policies.convbert import ConvBertPolicy
-from parallelformers.policies.ctrl import CTRLPolicy
-from parallelformers.policies.deberta import DebertaPolicy
-from parallelformers.policies.deberta_v2 import DebertaV2Policy
-from parallelformers.policies.deit import DeiTPolicy
-from parallelformers.policies.detr import DetrDecoderPolicy, DetrEncoderPolicy
-from parallelformers.policies.distil_bert import DistilBertPolicy
-from parallelformers.policies.electra import ElectraPolicy
-from parallelformers.policies.fsmt import FSMTDecoderPolicy, FSMTEncoderPolicy
-from parallelformers.policies.funnel import FunnelPolicy
-from parallelformers.policies.gpt2 import GPT2Policy
-from parallelformers.policies.gpt_neo import GPTNeoPolicy
-from parallelformers.policies.hubert import HubertPolicy
-from parallelformers.policies.ibert import IBertPolicy
-from parallelformers.policies.layoutlm import LayoutLMPolicy
-from parallelformers.policies.led import LEDDecoderPolicy, LEDEncoderPolicy
-from parallelformers.policies.longformer import LongformerPolicy
-from parallelformers.policies.luke import LukePolicy
-from parallelformers.policies.lxmert import LxmertPolicy
-from parallelformers.policies.m2m_100 import (
-    M2M100DecoderPolicy,
-    M2M100EncoderPolicy,
-)
-from parallelformers.policies.marian import (
-    MarianDecoderPolicy,
-    MarianEncoderPolicy,
-)
-from parallelformers.policies.mbart import (
-    MBartDecoderPolicy,
-    MBartEncoderPolicy,
-)
-from parallelformers.policies.mobilebert import MobileBertPolicy
-from parallelformers.policies.mpnet import MPNetEncoderPolicy, MPNetLayerPolicy
-from parallelformers.policies.openai import OpenAIGPTPolicy
-from parallelformers.policies.pegasus import (
-    PegasusDecoderPolicy,
-    PegasusEncoderPolicy,
-)
-from parallelformers.policies.prophetnet import (
-    ProphetNetDecoderPolicy,
-    ProphetNetEncoderPolicy,
-)
-from parallelformers.policies.reformer import ReformerPolicy
-from parallelformers.policies.roberta import RobertaPolicy
-from parallelformers.policies.roformer import RoformerPolicy
-from parallelformers.policies.speech_to_text import (
-    Speech2TextDecoderPolicy,
-    Speech2TextEncoderPolicy,
-)
-from parallelformers.policies.t5 import T5Policy
-from parallelformers.policies.tapas import TapasPolicy
-from parallelformers.policies.transfo_xl import TransfoXLPolicy
-from parallelformers.policies.visual_bert import VisualBertPolicy
-from parallelformers.policies.vit import ViTPolicy
-from parallelformers.policies.wav2vec import Wav2VecPolicy
-from parallelformers.policies.xlm import XLMAttentionPolicy, XLMMLPPolicy
-from parallelformers.policies.xlnet import XLNetPolicy
 
 
 class AutoPolicy:
     """Class for finds automatically appropriate policies for the current model"""
 
-    def get_policy(self, model: nn.Module) -> Union[List[Policy], None]:
-        """
-        Find appropriate policies for the current model
+    def __init__(self):
+        self.builtin_policies = {}
 
-        Args:
-            model (nn.Module): model to parallelize
+        with suppress(Exception):
+            from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoPreTrainedModel
+            from parallelformers.policies.gpt_neo import GPTNeoPolicy
+            self.builtin_policies[GPTNeoPreTrainedModel] = [
+                GPTNeoPolicy,
+            ]
 
-        Returns:
-            Union[List[Policy], None]: appropriate policies or none
-        """
-        for k, v in self.available().items():
-            if isinstance(model, k):
-                return v
-
-        return None
+        with suppress(Exception):
+            from transformers.models.bert.modeling_bert import BertPreTrainedModel
+            from parallelformers.policies.bert import BertPolicy
+            self.builtin_policies[BertPreTrainedModel] = [
+                BertPolicy,
+            ]
 
-    @staticmethod
-    def available():
-        """Dictionary of available models and policies"""
-        return {
-            GPTNeoPreTrainedModel: [GPTNeoPolicy],
-            BertPreTrainedModel: [BertPolicy],
-            BartPretrainedModel: [
+        with suppress(Exception):
+            from transformers.models.bart.modeling_bart import BartPretrainedModel
+            from parallelformers.policies.bart import BartEncoderPolicy, BartDecoderPolicy
+            self.builtin_policies[BartPretrainedModel] = [
                 BartEncoderPolicy,
                 BartDecoderPolicy,
-            ],
-            BlenderbotPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.blenderbot.modeling_blenderbot import BlenderbotPreTrainedModel
+            from parallelformers.policies.blenderbot import BlenderbotEncoderPolicy, BlenderbotDecoderPolicy
+            self.builtin_policies[BlenderbotPreTrainedModel] = [
                 BlenderbotEncoderPolicy,
                 BlenderbotDecoderPolicy,
-            ],
-            DebertaPreTrainedModel: [DebertaPolicy],
-            TransfoXLPreTrainedModel: [TransfoXLPolicy],
-            RobertaPreTrainedModel: [RobertaPolicy],
-            AlbertPreTrainedModel: [AlbertPolicy],
-            GPT2PreTrainedModel: [GPT2Policy],
-            CTRLPreTrainedModel: [CTRLPolicy],
-            DebertaV2PreTrainedModel: [DebertaV2Policy],
-            OpenAIGPTPreTrainedModel: [OpenAIGPTPolicy],
-            ElectraPreTrainedModel: [ElectraPolicy],
-            BlenderbotSmallPreTrainedModel: [
-                BlenderbotSmallEncoderPolicy,
-                BlenderbotSmallDecoderPolicy,
-            ],
-            DistilBertPreTrainedModel: [DistilBertPolicy],
-            ConvBertPreTrainedModel: [ConvBertPolicy],
-            BertGenerationPreTrainedModel: [BertPolicy],
-            BigBirdPreTrainedModel: [BigBirdPolicy],
-            BigBirdPegasusPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.deberta.modeling_deberta import DebertaPreTrainedModel
+            from parallelformers.policies.deberta import DebertaPolicy
+            self.builtin_policies[DebertaPreTrainedModel] = [
+                DebertaPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.transfo_xl.modeling_transfo_xl import TransfoXLPreTrainedModel
+            from parallelformers.policies.transfo_xl import TransfoXLPolicy
+            self.builtin_policies[TransfoXLPreTrainedModel] = [
+                TransfoXLPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
+            from parallelformers.policies.roberta import RobertaPolicy
+            self.builtin_policies[RobertaPreTrainedModel] = [
+                RobertaPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.albert.modeling_albert import AlbertPreTrainedModel
+            from parallelformers.policies.albert import AlbertPolicy
+            self.builtin_policies[AlbertPreTrainedModel] = [
+                AlbertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel
+            from parallelformers.policies.gpt2 import GPT2Policy
+            self.builtin_policies[GPT2PreTrainedModel] = [
+                GPT2Policy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.ctrl.modeling_ctrl import CTRLPreTrainedModel
+            from parallelformers.policies.ctrl import CTRLPolicy
+            self.builtin_policies[CTRLPreTrainedModel] = [
+                CTRLPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2PreTrainedModel
+            from parallelformers.policies.deberta_v2 import DebertaV2Policy
+            self.builtin_policies[DebertaV2PreTrainedModel] = [
+                DebertaV2Policy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.openai.modeling_openai import OpenAIGPTPreTrainedModel
+            from parallelformers.policies.openai import OpenAIGPTPolicy
+            self.builtin_policies[OpenAIGPTPreTrainedModel] = [
+                OpenAIGPTPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.electra.modeling_electra import ElectraPreTrainedModel
+            from parallelformers.policies.electra import ElectraPolicy
+            self.builtin_policies[ElectraPreTrainedModel] = [
+                ElectraPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.blenderbot_small.modeling_blenderbot_small import BlenderbotSmallPreTrainedModel
+            from parallelformers.policies.blenderbot_small import BlenderbotSmallEncoderPolicy, BlenderbotSmallDecoderPolicy
+            self.builtin_policies[BlenderbotSmallPreTrainedModel] = [
+                BlenderbotSmallEncoderPolicy, BlenderbotSmallDecoderPolicy
+            ]
+
+        with suppress(Exception):
+            from transformers.models.distilbert.modeling_distilbert import DistilBertPreTrainedModel
+            from parallelformers.policies.distil_bert import DistilBertPolicy
+            self.builtin_policies[DistilBertPreTrainedModel] = [
+                DistilBertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.convbert.modeling_convbert import ConvBertPreTrainedModel
+            from parallelformers.policies.convbert import ConvBertPolicy
+            self.builtin_policies[ConvBertPreTrainedModel] = [
+                ConvBertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.bert_generation.modeling_bert_generation import BertGenerationPreTrainedModel
+            from parallelformers.policies.bert import BertPolicy
+            self.builtin_policies[BertGenerationPreTrainedModel] = [
+                BertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.big_bird.modeling_big_bird import BigBirdPreTrainedModel
+            from parallelformers.policies.bigbird import BigBirdPolicy
+            self.builtin_policies[BigBirdPreTrainedModel] = [
+                BigBirdPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusPreTrainedModel
+            from parallelformers.policies.bigbird_pegasus import BigBirdPegasusEncoderPolicy, BigBirdPegasusDecoderPolicy
+            self.builtin_policies[BigBirdPegasusPreTrainedModel] = [
                 BigBirdPegasusEncoderPolicy,
                 BigBirdPegasusDecoderPolicy,
-            ],
-            ViTPreTrainedModel: [ViTPolicy],
-            DeiTPreTrainedModel: [DeiTPolicy],
-            MBartPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.vit.modeling_vit import ViTPreTrainedModel
+            from parallelformers.policies.vit import ViTPolicy
+            self.builtin_policies[ViTPreTrainedModel] = [
+                ViTPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.deit.modeling_deit import DeiTPreTrainedModel
+            from parallelformers.policies.deit import DeiTPolicy
+            self.builtin_policies[DeiTPreTrainedModel] = [DeiTPolicy]
+
+        with suppress(Exception):
+            from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel
+            from parallelformers.policies.mbart import MBartEncoderPolicy, MBartDecoderPolicy
+            self.builtin_policies[MBartPreTrainedModel] = [
                 MBartEncoderPolicy,
                 MBartDecoderPolicy,
-            ],
-            T5PreTrainedModel: [T5Policy],
-            PegasusPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.t5.modeling_t5 import T5PreTrainedModel
+            from parallelformers.policies.t5 import T5Policy
+            self.builtin_policies[T5PreTrainedModel] = [
+                T5Policy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.pegasus.modeling_pegasus import PegasusPreTrainedModel
+            from parallelformers.policies.pegasus import PegasusEncoderPolicy, PegasusDecoderPolicy
+            self.builtin_policies[PegasusPreTrainedModel] = [
                 PegasusEncoderPolicy,
                 PegasusDecoderPolicy,
-            ],
-            PretrainedFSMTModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.fsmt.modeling_fsmt import PretrainedFSMTModel
+            from parallelformers.policies.fsmt import FSMTEncoderPolicy, FSMTDecoderPolicy
+            self.builtin_policies[PretrainedFSMTModel] = [
                 FSMTEncoderPolicy,
                 FSMTDecoderPolicy,
-            ],
-            XLMPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.xlm.modeling_xlm import XLMPreTrainedModel
+            from parallelformers.policies.xlm import XLMAttentionPolicy, XLMMLPPolicy
+            self.builtin_policies[XLMPreTrainedModel] = [
                 XLMAttentionPolicy,
                 XLMMLPPolicy,
-            ],
-            M2M100PreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.m2m_100.modeling_m2m_100 import M2M100PreTrainedModel
+            from parallelformers.policies.m2m_100 import M2M100EncoderPolicy, M2M100DecoderPolicy
+            self.builtin_policies[M2M100PreTrainedModel] = [
                 M2M100EncoderPolicy,
                 M2M100DecoderPolicy,
-            ],
-            MarianPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.marian.modeling_marian import MarianPreTrainedModel
+            from parallelformers.policies.marian import MarianEncoderPolicy, MarianDecoderPolicy
+            self.builtin_policies[MarianPreTrainedModel] = [
                 MarianEncoderPolicy,
                 MarianDecoderPolicy,
-            ],
-            MobileBertPreTrainedModel: [MobileBertPolicy],
-            MPNetPreTrainedModel: [
-                MPNetLayerPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.mobilebert.modeling_mobilebert import MobileBertPreTrainedModel
+            from parallelformers.policies.mobilebert import MobileBertPolicy
+            self.builtin_policies[MobileBertPreTrainedModel] = [
+                MobileBertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.mpnet.modeling_mpnet import MPNetPreTrainedModel
+            from parallelformers.policies.mpnet import MPNetEncoderPolicy, MPNetLayerPolicy
+            self.builtin_policies[MPNetPreTrainedModel] = [
                 MPNetEncoderPolicy,
-            ],
-            LukePreTrainedModel: [LukePolicy],
-            DPRPretrainedContextEncoder: [BertPolicy],
-            DPRPretrainedQuestionEncoder: [BertPolicy],
-            DPRPretrainedReader: [BertPolicy],
-            LxmertPreTrainedModel: [LxmertPolicy],
-            HubertPreTrainedModel: [HubertPolicy],
-            Wav2Vec2PreTrainedModel: [Wav2VecPolicy],
-            XLNetPreTrainedModel: [XLNetPolicy],
-            RetriBertPreTrainedModel: [BertPolicy],
-            CLIPPreTrainedModel: [
+                MPNetLayerPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.luke.modeling_luke import LukePreTrainedModel
+            from parallelformers.policies.luke import LukePolicy
+            self.builtin_policies[LukePreTrainedModel] = [
+                LukePolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.dpr.modeling_dpr import DPRPretrainedReader, DPRPretrainedQuestionEncoder, DPRPretrainedContextEncoder
+            self.builtin_policies[DPRPretrainedReader] = [
+                BertPolicy,
+            ]
+
+            self.builtin_policies[DPRPretrainedQuestionEncoder] = [
+                BertPolicy,
+            ]
+
+            self.builtin_policies[DPRPretrainedContextEncoder] = [
+                BertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.lxmert.modeling_lxmert import LxmertPreTrainedModel
+            from parallelformers.policies.lxmert import LxmertPolicy
+            self.builtin_policies[LxmertPreTrainedModel] = [
+                LxmertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel
+            from parallelformers.policies.hubert import HubertPolicy
+            self.builtin_policies[HubertPreTrainedModel] = [
+                HubertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel
+            from parallelformers.policies.wav2vec import Wav2VecPolicy
+            self.builtin_policies[Wav2Vec2PreTrainedModel] = [
+                Wav2VecPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.xlnet.modeling_xlnet import XLNetPreTrainedModel
+            from parallelformers.policies.xlnet import XLNetPolicy
+            self.builtin_policies[XLNetPreTrainedModel] = [
+                XLNetPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.retribert.modeling_retribert import RetriBertPreTrainedModel
+            self.builtin_policies[RetriBertPreTrainedModel] = [
+                BertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.clip.modeling_clip import CLIPPreTrainedModel
+            from parallelformers.policies.clip import CLIPTextPolicy, CLIPVisionPolicy, CLIPLayerPolicy
+            self.builtin_policies[CLIPPreTrainedModel] = [
                 CLIPLayerPolicy,
                 CLIPTextPolicy,
                 CLIPVisionPolicy,
-            ],
-            DetrPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.detr.modeling_detr import DetrPreTrainedModel
+            from parallelformers.policies.detr import DetrEncoderPolicy, DetrDecoderPolicy
+            self.builtin_policies[DetrPreTrainedModel] = [
                 DetrEncoderPolicy,
                 DetrDecoderPolicy,
-            ],
-            ReformerPreTrainedModel: [ReformerPolicy],
-            LongformerPreTrainedModel: [LongformerPolicy],
-            RoFormerPreTrainedModel: [RoformerPolicy],
-            IBertPreTrainedModel: [IBertPolicy],
-            TapasPreTrainedModel: [TapasPolicy],
-            FunnelPreTrainedModel: [FunnelPolicy],
-            LayoutLMPreTrainedModel: [LayoutLMPolicy],
-            LEDPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.reformer.modeling_reformer import ReformerPreTrainedModel
+            from parallelformers.policies.reformer import ReformerPolicy
+            self.builtin_policies[ReformerPreTrainedModel] = [
+                ReformerPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.longformer.modeling_longformer import LongformerPreTrainedModel
+            from parallelformers.policies.longformer import LongformerPolicy
+            self.builtin_policies[LongformerPreTrainedModel] = [
+                LongformerPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.roformer.modeling_roformer import RoFormerPreTrainedModel
+            from parallelformers.policies.roformer import RoformerPolicy
+            self.builtin_policies[RoFormerPreTrainedModel] = [
+                RoformerPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.ibert.modeling_ibert import IBertPreTrainedModel
+            from parallelformers.policies.ibert import IBertPolicy
+            self.builtin_policies[IBertPreTrainedModel] = [
+                IBertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.tapas.modeling_tapas import TapasPreTrainedModel
+            from parallelformers.policies.tapas import TapasPolicy
+            self.builtin_policies[TapasPreTrainedModel] = [
+                TapasPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.funnel.modeling_funnel import FunnelPreTrainedModel
+            from parallelformers.policies.funnel import FunnelPolicy
+            self.builtin_policies[FunnelPreTrainedModel] = [
+                FunnelPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.layoutlm.modeling_layoutlm import LayoutLMPreTrainedModel
+            from parallelformers.policies.layoutlm import LayoutLMPolicy
+            self.builtin_policies[LayoutLMPreTrainedModel] = [
+                LayoutLMPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.led.modeling_led import LEDPreTrainedModel
+            from parallelformers.policies.led import LEDEncoderPolicy, LEDDecoderPolicy
+            self.builtin_policies[LEDPreTrainedModel] = [
                 LEDEncoderPolicy,
                 LEDDecoderPolicy,
-            ],
-            ProphetNetPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.prophetnet.modeling_prophetnet import ProphetNetPreTrainedModel
+            from parallelformers.policies.prophetnet import ProphetNetEncoderPolicy, ProphetNetDecoderPolicy
+            self.builtin_policies[ProphetNetPreTrainedModel] = [
                 ProphetNetEncoderPolicy,
                 ProphetNetDecoderPolicy,
-            ],
-            VisualBertPreTrainedModel: [VisualBertPolicy],
-            Speech2TextPreTrainedModel: [
+            ]
+
+        with suppress(Exception):
+            from transformers.models.visual_bert.modeling_visual_bert import VisualBertPreTrainedModel
+            from parallelformers.policies.visual_bert import VisualBertPolicy
+            self.builtin_policies[VisualBertPreTrainedModel] = [
+                VisualBertPolicy,
+            ]
+
+        with suppress(Exception):
+            from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextPreTrainedModel
+            from parallelformers.policies.speech_to_text import Speech2TextEncoderPolicy, Speech2TextDecoderPolicy
+            self.builtin_policies[Speech2TextPreTrainedModel] = [
                 Speech2TextEncoderPolicy,
                 Speech2TextDecoderPolicy,
-            ],
-        }
+            ]
+
+    def get_policy(self, model: nn.Module) -> Union[List[Policy], None]:
+        """
+        Find appropriate policies for the current model
+
+        Args:
+            model (nn.Module): model to parallelize
+
+        Returns:
+            Union[List[Policy], None]: appropriate policies or none
+        """
+        for k, v in self.available().items():
+            if isinstance(model, k):
+                return v
+
+        return None
+
+    def available(self):
+        """Dictionary of available models and policies"""
+        return self.builtin_policies
diff --git a/setup.py b/setup.py
index d70311e..1d0218b 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@ from setuptools import setup, find_packages
 
 install_requires = [
-    'transformers>=4.8',
+    'transformers>=4.2',
    'torch',
    "dacite",
    "dataclasses;python_version<'3.7'"
@@ -26,7 +26,7 @@ setup(
    name='parallelformers',
-    version='1.0',
+    version='1.0.1',
    description=
    'An Efficient Model Parallelization Toolkit for Deployment',
    long_description=long_description,
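
Note on the patch above: each built-in policy is now registered inside its own contextlib.suppress(Exception) block, so AutoPolicy imports cleanly even when the installed transformers release lacks one of the referenced model modules (hence the relaxed pin from transformers>=4.8 to >=4.2); architectures whose imports fail are simply absent from builtin_policies, and get_policy() returns None for them. A minimal usage sketch of that behavior follows — the GPT-2 checkpoint is only an illustrative choice, and the snippet assumes transformers and parallelformers are installed:

    from transformers import GPT2LMHeadModel

    from parallelformers.policies.base.auto import AutoPolicy

    # Hypothetical example model; get_policy() matches on the model's pretrained base class.
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    policies = AutoPolicy().get_policy(model)

    if policies is None:
        # Either this architecture has no built-in policy, or its transformers
        # module could not be imported in the current environment.
        print("no built-in policy available")
    else:
        print([p.__name__ for p in policies])  # expected: ['GPT2Policy']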