From 8dd795d2645a483befa4062fa505f35b99e5043c Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Wed, 15 Jan 2025 11:15:25 +0800 Subject: [PATCH] fix AttributeError caused by last_hidden_state --- angle_emb/angle.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index d345086..8e0e9bb 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -712,12 +712,13 @@ def __call__(self, :param return_mlm_logits: bool. Return logits or not. Default False. """ if layer_index == -1 and not return_all_layer_outputs: - ret = self.model(**inputs) - outputs = ret.last_hidden_state + ret = self.model(output_hidden_states=True, **inputs) + outputs = ret.last_hidden_state if hasattr(ret, 'last_hidden_state') else ret.hidden_states[-1] else: ret = self.model(output_hidden_states=True, return_dict=True, **inputs) all_layer_outputs = list(ret.hidden_states) - all_layer_outputs[-1] = ret.last_hidden_state + if hasattr(ret, 'last_hidden_state'): + all_layer_outputs[-1] = ret.last_hidden_state if return_all_layer_outputs: return (all_layer_outputs, ret.logits) if return_mlm_logits else all_layer_outputs outputs = all_layer_outputs[layer_index]