11 changes: 11 additions & 0 deletions atom/entrypoints/openai_server.py
@@ -32,6 +32,7 @@

# Constants
DEFAULT_TEMPERATURE = 1.0
DEFAULT_TOP_K = -1
DEFAULT_TOP_P = 1.0
DEFAULT_MAX_TOKENS = 256
CHAT_COMPLETION_OBJECT = "chat.completion"
@@ -63,6 +64,7 @@ class ChatCompletionRequest(BaseModel):
messages: Optional[List[ChatMessage]] = None
prompt: Optional[List[ChatMessage]] = None # Accept 'prompt' as alias
temperature: Optional[float] = DEFAULT_TEMPERATURE
top_k: Optional[int] = DEFAULT_TOP_K
top_p: Optional[float] = DEFAULT_TOP_P
max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
stop: Optional[List[str]] = None
@@ -86,6 +88,7 @@ class CompletionRequest(BaseModel):
model: Optional[str] = None
prompt: str
temperature: Optional[float] = DEFAULT_TEMPERATURE
top_k: Optional[int] = DEFAULT_TOP_K
top_p: Optional[float] = DEFAULT_TOP_P
max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
stop: Optional[List[str]] = None
@@ -253,9 +256,13 @@ def _build_sampling_params(
max_tokens: int,
stop_strings: Optional[List[str]],
ignore_eos: bool,
top_k: int = -1,
top_p: float = 1.0,
) -> SamplingParams:
return SamplingParams(
temperature=temperature,
top_k=top_k,
top_p=top_p,
max_tokens=max_tokens,
stop_strings=stop_strings,
ignore_eos=ignore_eos,
@@ -667,6 +674,8 @@ async def chat_completions(request: ChatCompletionRequest):
max_tokens=request.max_tokens,
stop_strings=request.stop,
ignore_eos=request.ignore_eos,
top_k=request.top_k,
top_p=request.top_p,
)

request_id = f"chatcmpl-{uuid.uuid4().hex}"
Expand Down Expand Up @@ -749,6 +758,8 @@ async def completions(request: CompletionRequest):
max_tokens=request.max_tokens,
stop_strings=request.stop,
ignore_eos=request.ignore_eos,
top_k=request.top_k,
top_p=request.top_p,
)

request_id = f"cmpl-{uuid.uuid4().hex}"
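With `top_k` and `top_p` accepted on both request models and forwarded into `_build_sampling_params`, clients can set them per request. A minimal usage sketch with the `requests` library; the route, host/port, and model name below are assumptions for illustration, not taken from this diff:

```python
import requests

# Assumed OpenAI-compatible route and local deployment; adjust to your setup.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",  # hypothetical model name
        "messages": [{"role": "user", "content": "Say hello."}],
        "temperature": 0.8,
        "top_k": 40,    # new field; -1 (DEFAULT_TOP_K) disables top-k filtering
        "top_p": 0.95,  # nucleus sampling threshold; 1.0 (DEFAULT_TOP_P) disables it
        "max_tokens": 64,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```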
46 changes: 38 additions & 8 deletions atom/model_engine/model_runner.py
@@ -852,6 +852,8 @@ def allocate_forward_vars(self):
"input_ids": self.tokenID_processor.input_ids,
"positions": CpuGpuBuffer(self.max_num_batched_tokens, **i64_kwargs),
"temperatures": CpuGpuBuffer(self.max_bs, **f32_kwargs),
"top_ks": CpuGpuBuffer(self.max_bs, **i32_kwargs),
"top_ps": CpuGpuBuffer(self.max_bs, **f32_kwargs),
# Keep enough space for MTP decode (max_q_len > 1).
"outputs": torch.empty(
self.max_num_batched_tokens, hidden_size, dtype=hidden_type
@@ -1305,23 +1307,45 @@ def prepare_inputs(self, batch: ScheduledBatch, input_ids: torch.Tensor = None):
)
return graph_bs

def prepare_sample(self, batch: ScheduledBatch) -> torch.Tensor:
def prepare_sample(
self, batch: ScheduledBatch
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bs = batch.total_seqs_num
buffer = self.forward_vars["temperatures"]
buffer.np[:bs] = batch.temperatures
return buffer.copy_to_gpu(bs)

temp_buffer = self.forward_vars["temperatures"]
temp_buffer.np[:bs] = batch.temperatures
temperatures = temp_buffer.copy_to_gpu(bs)

# For top_ks and top_ps, check uniformity on CPU before GPU copy.
# If all values are the same, only copy a single element to save bandwidth.
top_k_buffer = self.forward_vars["top_ks"]
top_k_buffer.np[:bs] = batch.top_ks
if bs > 1 and all(k == batch.top_ks[0] for k in batch.top_ks):
top_ks = top_k_buffer.copy_to_gpu(1)
else:
top_ks = top_k_buffer.copy_to_gpu(bs)

top_p_buffer = self.forward_vars["top_ps"]
top_p_buffer.np[:bs] = batch.top_ps
if bs > 1 and all(p == batch.top_ps[0] for p in batch.top_ps):
top_ps = top_p_buffer.copy_to_gpu(1)
else:
top_ps = top_p_buffer.copy_to_gpu(bs)

return temperatures, top_ks, top_ps

def prepare_model(self, batch: ScheduledBatch):
total_tokens_num = batch.total_tokens_num
assert total_tokens_num > 0

temperatures = self.prepare_sample(batch)
temperatures, top_ks, top_ps = self.prepare_sample(batch)
input_ids = self.tokenID_processor.prepare_input_ids(batch)
# self.debug(f"{input_ids=}")
self.prepare_inputs(batch, input_ids)
return (
input_ids,
temperatures,
top_ks,
top_ps,
)

def run_model(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
@@ -1348,13 +1372,15 @@ def postprocess(
batch: ScheduledBatch,
logits: torch.Tensor,
temperatures: torch.Tensor,
top_ks: torch.Tensor,
top_ps: torch.Tensor,
# following for draft
hidden_states: torch.Tensor,
) -> ScheduledBatchOutput:
spec_decode_metadata = get_forward_context().spec_decode_metadata
bs = batch.total_seqs_num
if spec_decode_metadata is None:
sampled_tokens = self.sampler(logits, temperatures)
sampled_tokens = self.sampler(logits, temperatures, top_ks, top_ps)
num_reject_tokens = self.tokenID_processor.default_num_rejected_tokens[:bs]
next_token_locs = num_reject_tokens
else:
@@ -1367,6 +1393,8 @@
bonus_token_ids = self.sampler(
logits=bonus_logits,
temperatures=temperatures,
top_ks=top_ks,
top_ps=top_ps,
)
# Validate shapes match expectations
if target_logits.shape[0] != len(spec_decode_metadata.draft_token_ids):
@@ -1429,12 +1457,14 @@ def postprocess(

@torch.inference_mode()
def forward(self, batch: ScheduledBatch) -> ScheduledBatchOutput:
input_ids, temperatures = self.prepare_model(batch)
input_ids, temperatures, top_ks, top_ps = self.prepare_model(batch)
logits, hidden_states = self.run_model(input_ids)
fwd_output = self.postprocess(
batch,
logits,
temperatures,
top_ks,
top_ps,
hidden_states,
)
reset_forward_context()
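The sampler is now invoked with per-batch `top_ks` and `top_ps` tensors (length `bs`, or length 1 when every sequence shares the same value, as prepared in `prepare_sample`). The sampler implementation itself is not part of this diff; the sketch below only illustrates how per-sequence top-k/top-p filtering is commonly applied to logits, with broadcasting covering the single-element case. All names, shapes, and conventions here are assumptions, not the project's `Sampler` API:

```python
import torch

def sample_top_k_top_p(
    logits: torch.Tensor,        # [bs, vocab_size]
    temperatures: torch.Tensor,  # [bs], assumed > 0 (no greedy path shown)
    top_ks: torch.Tensor,        # [bs] or [1] int tensor; -1 means "no top-k"
    top_ps: torch.Tensor,        # [bs] or [1] float tensor; 1.0 means "no top-p"
) -> torch.Tensor:
    bs, vocab_size = logits.shape
    logits = logits / temperatures.unsqueeze(-1)

    # Top-k: mask out everything below each row's k-th largest logit.
    k = top_ks.expand(bs).clone().long()
    k[k < 1] = vocab_size  # -1 (or 0) disables top-k filtering
    sorted_logits, _ = logits.sort(dim=-1, descending=True)
    kth_logit = sorted_logits.gather(-1, (k - 1).unsqueeze(-1))
    logits = logits.masked_fill(logits < kth_logit, float("-inf"))

    # Top-p (nucleus): keep the smallest prefix of the sorted distribution
    # whose cumulative probability covers top_p, per row.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    sorted_probs = sorted_logits.softmax(dim=-1)
    cumulative = sorted_probs.cumsum(dim=-1)
    drop = (cumulative - sorted_probs) > top_ps.expand(bs).unsqueeze(-1)
    sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
    logits = torch.full_like(logits, float("-inf")).scatter(
        -1, sorted_idx, sorted_logits
    )

    # Sample one token per row from the filtered distribution.
    return torch.multinomial(logits.softmax(dim=-1), num_samples=1).squeeze(-1)
```

Under this reading, the uniformity check in `prepare_sample` means a batch where every request uses the defaults only pays for a single-element host-to-device copy, while an `expand`-style broadcast keeps the sampling math indifferent to whether it received one value or `bs`.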
6 changes: 6 additions & 0 deletions atom/model_engine/scheduler.py
@@ -147,6 +147,12 @@ def __init__(
self.mamba_block_tables = [
seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table
]
self.top_ks = np.asarray(
[seq.top_k for seq in seqs.values()], dtype=np.int32
)
self.top_ps = np.asarray(
[seq.top_p for seq in seqs.values()], dtype=np.float32
)

offs = self.context_lens - self.num_rejected - self.num_scheduled_tokens
self.scheduled_tokens = np.empty(total_tokens_num, dtype=np.int32)
2 changes: 2 additions & 0 deletions atom/model_engine/sequence.py
@@ -58,6 +58,8 @@ def __init__(
self.block_table = []
self.mamba_block_table = []
self.temperature = sampling_params.temperature
self.top_k = sampling_params.top_k
self.top_p = sampling_params.top_p
self.max_tokens = sampling_params.max_tokens
self.ignore_eos = sampling_params.ignore_eos
self.stop_strings = sampling_params.stop_strings