Description
Currently, the sentiment BERT feature is memory- and time-inefficient. This makes it hard to run even minimal examples such as this one: https://colab.research.google.com/drive/1y4lIl3aoEFCMTK9Kn7R-M-sOEa_BI8Rz?usp=sharing
(the FB there fails because it runs out of RAM on Colab).
We might want to consider taking some steps to make this more efficient, including:

- moving operations to the GPU if possible, e.g.:
import numpy as np
import pandas as pd
import torch
from scipy.special import softmax

# `tokenizer` and `model_bert` are assumed to be the module-level Hugging Face
# tokenizer and sequence-classification model already used by this feature.


def get_sentiment(texts):
    """
    Analyzes the sentiment of the given list of texts using a BERT model and
    returns a DataFrame with scores for positive, negative, and neutral sentiments.

    :param texts: The list of input texts to analyze.
    :type texts: list of str
    :return: A DataFrame with sentiment scores.
    :rtype: pd.DataFrame
    """
    # Run inference on the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_bert.to(device)

    # Keep only non-null, non-empty texts; all other rows get NaN scores.
    texts_series = pd.Series(texts)
    mask = texts_series.apply(lambda x: pd.notnull(x) and x.strip() != '')
    non_null_non_empty_texts = texts_series[mask].tolist()

    columns = ['positive_bert', 'negative_bert', 'neutral_bert']
    if not non_null_non_empty_texts:
        return pd.DataFrame(np.nan, index=texts_series.index, columns=columns)

    # Tokenize, move the tensors to the same device as the model, and score.
    encoded = tokenizer(non_null_non_empty_texts, padding=True, truncation=True,
                        max_length=512, return_tensors='pt').to(device)
    with torch.no_grad():
        output = model_bert(**encoded)
    scores = output[0].cpu().detach().numpy()
    scores = softmax(scores, axis=1)

    sent_dict = {
        'positive_bert': scores[:, 2],
        'negative_bert': scores[:, 0],
        'neutral_bert': scores[:, 1],
    }
    non_null_sent_df = pd.DataFrame(sent_dict)

    # Reassemble a full-length frame, leaving NaN rows for the skipped texts.
    sent_df = pd.DataFrame(np.nan, index=texts_series.index, columns=columns)
    sent_df.loc[mask, columns] = non_null_sent_df.values
    return sent_df
- releasing memory explicitly, e.g.:
del encoded, scores, non_null_sent_df
torch.cuda.empty_cache()
- taking a batching approach, like we do with SBERT vectors (see the sketch below).
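
For the batching idea, here is a rough sketch of what it could look like, reusing the same `tokenizer` / `model_bert` globals as in the snippet above; the function name `get_sentiment_batched` and the batch size of 32 are placeholders, not existing code:

```python
import numpy as np
import pandas as pd
import torch
from scipy.special import softmax


def get_sentiment_batched(texts, batch_size=32):
    """Like get_sentiment, but scores the texts in fixed-size batches so peak
    memory stays bounded regardless of how many texts are passed in."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_bert.to(device)

    texts_series = pd.Series(texts)
    mask = texts_series.apply(lambda x: pd.notnull(x) and x.strip() != '')
    valid_texts = texts_series[mask].tolist()

    columns = ['positive_bert', 'negative_bert', 'neutral_bert']
    sent_df = pd.DataFrame(np.nan, index=texts_series.index, columns=columns)
    if not valid_texts:
        return sent_df

    all_scores = []
    for start in range(0, len(valid_texts), batch_size):
        batch = valid_texts[start:start + batch_size]
        encoded = tokenizer(batch, padding=True, truncation=True,
                            max_length=512, return_tensors='pt').to(device)
        with torch.no_grad():
            output = model_bert(**encoded)
        # Move the logits back to the CPU right away so GPU memory can be
        # reclaimed between batches.
        all_scores.append(output[0].cpu().numpy())
        del encoded, output
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    scores = softmax(np.concatenate(all_scores, axis=0), axis=1)
    sent_df.loc[mask, columns] = np.column_stack(
        [scores[:, 2], scores[:, 0], scores[:, 1]])
    return sent_df
```

With fixed-size batches the peak memory is set by the batch size rather than the input size, which should keep examples like the Colab notebook above from running out of RAM.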