Skip to content

Commit

Permalink
⚠️[FIX] Fix OpenAIChatModel /chat/completion output with logits
Browse files Browse the repository at this point in the history
  • Loading branch information
fairyshine committed Sep 28, 2024
1 parent 8971fdd commit 9dd8903
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
20 changes: 11 additions & 9 deletions src/fastmindapi/algo/tree/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@ class MCTSState(ABC):
def __init__(self):
...

def next_state(self):
def next_state(self): # 🌟
...

def terminal(self):
def terminal(self) -> bool: # 🌟
...

def reward(self):
def reward(self): # 🌟
...

def __hash__(self):
def __hash__(self): # 🌟
...

def __eq__(self,other):
...
if hash(self)==hash(other):
return True
return False

def __repr__(self):
...
Expand All @@ -48,7 +50,7 @@ def add_child(self,child_state):
def update(self,reward):
self.reward+=reward
self.visits+=1
def fully_expanded(self, num_moves_lambda):
def fully_expanded(self, num_moves_lambda) -> bool: # 🌟 rewrite
num_moves = self.state.num_moves
if num_moves_lambda is not None:
num_moves = num_moves_lambda(self)
Expand All @@ -61,7 +63,7 @@ def __repr__(self):

class MCTS:
@classmethod
def UCTSearch(cls, budget, root, num_moves_lambda = None):
def UCT_search(cls, budget, root, num_moves_lambda = None):
for iter in range(int(budget)):
if iter%10000==9999:
logger.info("simulation: %d"%iter)
Expand Down Expand Up @@ -139,7 +141,7 @@ def __init__(self, value=0, moves=[], turn=NUM_TURNS):
self.turn=turn
self.moves=moves
def next_state(self):
nextmove=random.choice([x*self.turn for x in self.MOVES])
nextmove=random.choice([x*self.turn for x in self.MOVES])
next=TESTState(self.value+nextmove, self.moves+[nextmove],self.turn-1)
return next
def terminal(self):
Expand All @@ -166,7 +168,7 @@ def __repr__(self):

current_node=MCTSNode(TESTState())
for l in range(args.levels):
current_node=MCTS.UCTSearch(args.num_sims/(l+1),current_node)
current_node=MCTS.UCT_search(args.num_sims/(l+1),current_node)
print("level %d"%l)
print("Num Children: %d"%len(current_node.children))
for i,c in enumerate(current_node.children):
Expand Down
6 changes: 4 additions & 2 deletions src/fastmindapi/model/openai/ChatModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ def generate(self,
],
max_completion_tokens=max_new_tokens,
logprobs=return_logits,
top_logprobs=logits_top_k,
top_logprobs=logits_top_k if return_logits else None,
stop=stop_strings
)
break
except Exception as e:
logger.info(f"【Error】: {e}")
output_text = completion.choices[0].message.content
logits_list = convert_openai_logprobs(completion.choices[0].logprobs)
logits_list = None
if return_logits:
logits_list = convert_openai_logprobs(completion.choices[0].logprobs)
generation_output = {"output_text": output_text,
# "input_id_list": input_id_list,
# "input_token_list": input_token_list,
Expand Down

0 comments on commit 9dd8903

Please sign in to comment.