diff --git a/alfred/fm/remote/grpc.py b/alfred/fm/remote/grpc.py index 24faa48..6ef26c6 100644 --- a/alfred/fm/remote/grpc.py +++ b/alfred/fm/remote/grpc.py @@ -4,6 +4,7 @@ import json import logging import socket +import torch.nn.functional as F from concurrent import futures from typing import Optional, Union, Iterable, Tuple, Any, List @@ -122,11 +123,19 @@ def _run_req_gen(): output = [] for response in self.stub.Run(_run_req_gen()): if response.ranked: + logits = ast.literal_eval(response.logit) + candidates = list(logits.keys()) + logit_values = torch.tensor(list(logits.values())) + probabilities = F.softmax(logit_values, dim=0) + scores = { + candidate: prob.item() for candidate, prob in zip(candidates, probabilities) + } output.append( RankedResponse( **{ "prediction": response.message, - "scores": ast.literal_eval(response.logit), + "scores": scores, + "logit": logits, "embeddings": bytes_to_tensor(response.embedding), } ) @@ -241,7 +250,7 @@ def Run(self, request_iterator, context): yield query_pb2.RunResponse( message=response.prediction, ranked=True, - logit=str(response.scores), + logit=str(response.logits), embedding=tensor_to_bytes(response.embeddings), ) else: