diff --git a/alfred/fm/remote/grpc.py b/alfred/fm/remote/grpc.py index 6ef26c6..24faa48 100644 --- a/alfred/fm/remote/grpc.py +++ b/alfred/fm/remote/grpc.py @@ -4,7 +4,6 @@ import json import logging import socket -import torch.nn.functional as F from concurrent import futures from typing import Optional, Union, Iterable, Tuple, Any, List @@ -123,19 +122,11 @@ def _run_req_gen(): output = [] for response in self.stub.Run(_run_req_gen()): if response.ranked: - logits = ast.literal_eval(response.logit) - candidates = list(logits.keys()) - logit_values = torch.tensor(list(logits.values())) - probabilities = F.softmax(logit_values, dim=0) - scores = { - candidate: prob.item() for candidate, prob in zip(candidates, probabilities) - } output.append( RankedResponse( **{ "prediction": response.message, - "scores": scores, - "logit": logits, + "scores": ast.literal_eval(response.logit), "embeddings": bytes_to_tensor(response.embedding), } ) @@ -250,7 +241,7 @@ def Run(self, request_iterator, context): yield query_pb2.RunResponse( message=response.prediction, ranked=True, - logit=str(response.logits), + logit=str(response.scores), embedding=tensor_to_bytes(response.embeddings), ) else: