Skip to content

Commit

Permalink
Easily access all layers using output_layers="all" instead of a tuple.
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed Oct 21, 2022
1 parent 842ac7e commit 191b24b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setuptools.setup(
name="transformers_embedder",
version="3.0.6",
version="3.0.7",
author="Riccardo Orlando",
author_email="orlandoricc@gmail.com",
description="Word level transformer based embeddings",
Expand Down
20 changes: 14 additions & 6 deletions transformers_embedder/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ class TransformersEmbedder(torch.nn.Module):
that is not deterministic. The ``sparse`` strategy is deterministic but it is not comptabile
with ONNX. When ``subword_pooling_strategy`` is ``none``, the sub-word embeddings are not
pooled.
output_layers (`tuple`, optional, defaults to `(-4, -3, -2, -1)`):
Which hidden layers to get from the transformer model.
output_layers (`tuple`, `list`, `str`, optional, defaults to `(-4, -3, -2, -1)`):
Which hidden layers to get from the transformer model. If ``output_layers`` is ``all``,
all the hidden layers are returned. If ``output_layers`` is a tuple or a list, the hidden
layers are selected according to the indexes in the tuple or list. If ``output_layers`` is
a string, it must be ``all``.
fine_tune (`bool`, optional, defaults to `True`):
If ``True``, the transformer model is fine-tuned during training.
return_all (`bool`, optional, defaults to `False`):
Expand All @@ -63,7 +66,7 @@ def __init__(
model: Union[str, tr.PreTrainedModel],
layer_pooling_strategy: str = "last",
subword_pooling_strategy: str = "scatter",
output_layers: Sequence[int] = (-4, -3, -2, -1),
output_layers: Union[Sequence[int], str] = (-4, -3, -2, -1),
fine_tune: bool = True,
return_all: bool = False,
from_pretrained: bool = True,
Expand Down Expand Up @@ -94,9 +97,10 @@ def __init__(
self.layer_pooling_strategy = layer_pooling_strategy
self.subword_pooling_strategy = subword_pooling_strategy

self._scalar_mix: Optional[ScalarMix] = None
if layer_pooling_strategy == "scalar_mix":
self._scalar_mix = ScalarMix(len(output_layers))
if output_layers == "all":
output_layers = tuple(
range(self.transformer_model.config.num_hidden_layers)
)

# check output_layers is well defined
if (
Expand All @@ -110,6 +114,10 @@ def __init__(
)
self.output_layers = output_layers

self._scalar_mix: Optional[ScalarMix] = None
if layer_pooling_strategy == "scalar_mix":
self._scalar_mix = ScalarMix(len(output_layers))

# check if return all transformer outputs
self.return_all = return_all

Expand Down

0 comments on commit 191b24b

Please sign in to comment.