From 69ec78b86be6a1ec365af1d767482fbe4e486264 Mon Sep 17 00:00:00 2001 From: Peilin Yu <54282945+dotpyu@users.noreply.github.com> Date: Fri, 5 Jul 2024 14:45:16 -0400 Subject: [PATCH] =?UTF-8?q?Revert=20"Logit=20Transmission=20via=20gRPC:=20?= =?UTF-8?q?Relocating=20scores=20normalization=20to=20local=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- alfred/fm/remote/grpc.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/alfred/fm/remote/grpc.py b/alfred/fm/remote/grpc.py index 6ef26c6..24faa48 100644 --- a/alfred/fm/remote/grpc.py +++ b/alfred/fm/remote/grpc.py @@ -4,7 +4,6 @@ import json import logging import socket -import torch.nn.functional as F from concurrent import futures from typing import Optional, Union, Iterable, Tuple, Any, List @@ -123,19 +122,11 @@ def _run_req_gen(): output = [] for response in self.stub.Run(_run_req_gen()): if response.ranked: - logits = ast.literal_eval(response.logit) - candidates = list(logits.keys()) - logit_values = torch.tensor(list(logits.values())) - probabilities = F.softmax(logit_values, dim=0) - scores = { - candidate: prob.item() for candidate, prob in zip(candidates, probabilities) - } output.append( RankedResponse( **{ "prediction": response.message, - "scores": scores, - "logit": logits, + "scores": ast.literal_eval(response.logit), "embeddings": bytes_to_tensor(response.embedding), } ) @@ -250,7 +241,7 @@ def Run(self, request_iterator, context): yield query_pb2.RunResponse( message=response.prediction, ranked=True, - logit=str(response.logits), + logit=str(response.scores), embedding=tensor_to_bytes(response.embeddings), ) else: