From b445a6bc77111657a4304a153c94cd87874a60be Mon Sep 17 00:00:00 2001
From: and_gate <38602277+senbeiasano@users.noreply.github.com>
Date: Mon, 18 Mar 2024 18:38:08 +0800
Subject: [PATCH] support length penalty (#358)

---
 .../server/router/model_infer/infer_batch.py    |  4 +++-
 lightllm/server/router/model_infer/model_rpc.py |  7 ++++---
 .../server/router/model_infer/post_process.py   | 16 ++++++++++++----
 lightllm/server/sampling_params.py              | 11 ++++++++++-
 4 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index cf9316f9e..cf619da6c 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -4,7 +4,7 @@
 import collections
 
 from dataclasses import dataclass, field
-from typing import List, Dict
+from typing import List, Dict, Tuple
 from lightllm.common.req_manager import ReqManager
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.utils.infer_utils import mark_start, mark_end
@@ -25,6 +25,7 @@ def __init__(
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
+        exponential_decay_length_penalty: Tuple[int, float] = (1, 1.0),
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -34,6 +35,7 @@ def __init__(
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.repetition_penalty = repetition_penalty
+        self.exponential_decay_length_penalty = exponential_decay_length_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index c49fd12ed..a07cbe356 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -100,6 +100,7 @@ def exposed_init_model(self, kvargs):
 
         try:
             self.model_type = model_cfg.get("model_type", "")
+            self.eos_id = model_cfg.get("eos_token_id", 2)
             if self.model_type == "bloom":
                 self.model = BloomTpPartModel(model_kvargs)
             elif self.model_type == "llama":
@@ -279,7 +280,7 @@ def forward(self, batch_id, is_prefill):
             kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache, self.model.mem_manager)
 
         logits = self.model.forward(**kwargs)
-        next_token_ids, next_token_probs = sample(logits, run_reqs)
+        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
         next_token_ids = next_token_ids.detach().cpu().numpy()
         next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
 
@@ -317,7 +318,7 @@ def _prefill_to_return_all_prompt_logprobs(self, batch_id):
         last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1
         logits = prompt_all_logits[last_index, :]
 
-        next_token_ids, next_token_probs = sample(logits, run_reqs)
+        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
         next_token_ids = next_token_ids.detach().cpu().numpy()
         next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
 
@@ -370,7 +371,7 @@ def splitfuse_forward(self, batch_id):
         all_reqs.extend(prefill_reqs)
 
         logits = self.model.splitfuse_forward(**kwargs)
-        next_token_ids, next_token_probs = sample(logits, all_reqs)
+        next_token_ids, next_token_probs = sample(logits, all_reqs, self.eos_id)
         next_token_ids = next_token_ids.detach().cpu().numpy()
         next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
 
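As context for the model_rpc.py change above: the new self.eos_id is read from the model config with a fallback of 2 and then threaded into every sample() call. Below is a minimal standalone sketch of that lookup, assuming a Hugging Face-style config.json; the read_eos_id helper and the file path are illustrative only and are not part of the patch.

import json

def read_eos_id(config_path: str, default: int = 2) -> int:
    # Load the model's config.json and pick up its EOS token id.
    with open(config_path) as f:
        model_cfg = json.load(f)
    # Some configs store eos_token_id as a single int, others as a list of ids;
    # fall back to `default` when the key is absent (the patch defaults to 2).
    eos = model_cfg.get("eos_token_id", default)
    return eos[0] if isinstance(eos, list) else eos

eos_id = read_eos_id("config.json")  # e.g. 2 for many LLaMA-family configs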
diff --git a/lightllm/server/router/model_infer/post_process.py b/lightllm/server/router/model_infer/post_process.py
index 9d36997bd..4b6184d4e 100644
--- a/lightllm/server/router/model_infer/post_process.py
+++ b/lightllm/server/router/model_infer/post_process.py
@@ -1,14 +1,15 @@
 import re
 import torch
-from typing import List
+from typing import List, Tuple
 from lightllm.server.router.model_infer.infer_batch import InferBatch
 from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
 
 
-def sample(logits, reqs):
+def sample(logits, reqs, eos_id=2):
     logits = logits.contiguous()
-    presence_penalties, frequency_penalties, repetition_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch = _get_post_sample_tensors(reqs)
+    presence_penalties, frequency_penalties, repetition_penalties, exponential_decay_length_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch, length_penalty_idx = _get_post_sample_tensors(reqs)
     apply_penalty(logits, presence_penalties, frequency_penalties, repetition_penalties, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch)
+    logits[:, eos_id] = logits[:, eos_id] + torch.abs(logits[:, eos_id]) * (torch.pow(exponential_decay_length_penalties, length_penalty_idx).view((-1, 1)) - 1)
     logits.div_(temperatures.view((-1, 1)))
     probs = torch.softmax(logits, dim=-1)
     probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
@@ -33,6 +34,7 @@ def _get_post_sample_tensors(reqs):
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
     repetition_penalties: List[float] = []
+    exponential_decay_length_penalties: List[float] = []
     temperatures: List[float] = []
     top_ps: List[float] = []
     top_ks: List[int] = []
@@ -40,12 +42,16 @@
     p_token_counts: List[int] = []
     p_seq_len: List[int] = [0,]
     p_max_len_in_batch: int = 0
+    length_penalty_idx: List[int] = []
     for i, req_obj in enumerate(reqs):
         id_to_count = req_obj.out_token_id_count
         sample_param = req_obj.sampling_param
         presence_penalties.append(sample_param.presence_penalty)
         frequency_penalties.append(sample_param.frequency_penalty)
         repetition_penalties.append(sample_param.repetition_penalty)
+        exponential_decay_length_penalties.append(sample_param.exponential_decay_length_penalty[1])
+        length_penalty_idx.append(max(len(req_obj.input_token_ids) - req_obj.prompt_len - sample_param.exponential_decay_length_penalty[0], 0))
+
         temperatures.append(sample_param.temperature)
         top_ps.append(sample_param.top_p)
         top_ks.append(sample_param.top_k)
@@ -59,6 +65,7 @@
     presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
     frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
     repetition_penalties = torch.tensor(repetition_penalties, dtype=torch.float, device="cuda")
+    exponential_decay_length_penalties = torch.tensor(exponential_decay_length_penalties, dtype=torch.float, device="cuda")
     temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
     top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
     top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
@@ -66,4 +73,5 @@
     p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
     p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
     p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
-    return presence_penalties, frequency_penalties, repetition_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch
\ No newline at end of file
+    length_penalty_idx = torch.tensor(length_penalty_idx, dtype=torch.int32, device="cuda")
+    return presence_penalties, frequency_penalties, repetition_penalties, exponential_decay_length_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch, length_penalty_idx
\ No newline at end of file
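The core of this patch is the EOS-logit adjustment added to sample() above: each request's EOS logit is pushed upward by a factor that grows exponentially once the generated length passes the configured start offset. The sketch below reproduces that behaviour in isolation for a single integer eos_id; the apply_length_penalty helper and the example tensors are illustrative only, not code from the patch.

import torch

def apply_length_penalty(logits: torch.Tensor, eos_id: int,
                         penalties: torch.Tensor, length_penalty_idx: torch.Tensor) -> torch.Tensor:
    # penalties: per-request decay factor (>= 1.0), i.e. exponential_decay_length_penalty[1]
    # length_penalty_idx: per-request max(generated_len - start_offset, 0)
    scale = torch.pow(penalties, length_penalty_idx.float())  # grows with generated length
    eos_logit = logits[:, eos_id]
    # Boost the EOS logit by (scale - 1) times its magnitude, so EOS becomes
    # increasingly likely the further a request runs past the start offset.
    logits[:, eos_id] = eos_logit + torch.abs(eos_logit) * (scale - 1.0)
    return logits

# Two requests with identical EOS logits; the second is 8 tokens past the start offset,
# so its EOS logit is nudged toward zero (i.e. made more probable).
logits = torch.full((2, 32000), -2.0)
out = apply_length_penalty(logits, eos_id=2,
                           penalties=torch.tensor([1.05, 1.05]),
                           length_penalty_idx=torch.tensor([0, 8]))

With the default value (1, 1.0) the scale factor is always 1.0, so the logits are left untouched and requests that do not set the parameter behave exactly as before.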
diff --git a/lightllm/server/sampling_params.py b/lightllm/server/sampling_params.py
index e59e349eb..28e1d7a3f 100644
--- a/lightllm/server/sampling_params.py
+++ b/lightllm/server/sampling_params.py
@@ -1,5 +1,5 @@
 """Sampling parameters for text generation."""
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple
 
 _SAMPLING_EPS = 1e-5
 
@@ -12,6 +12,7 @@ def __init__(
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
+        exponential_decay_length_penalty: Tuple[int, float] = (1, 1.0),
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1, # -1 is for all
@@ -23,6 +24,7 @@ def __init__(
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.repetition_penalty = repetition_penalty
+        self.exponential_decay_length_penalty = exponential_decay_length_penalty
        self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -53,6 +55,12 @@ def verify(self):
             raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
         if self.max_new_tokens < 1:
             raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
+        if len(self.exponential_decay_length_penalty) != 2:
+            raise ValueError(f"exponential_decay_length_penalty must be a tuple of (int, float), got {self.exponential_decay_length_penalty}.")
+        if not isinstance(self.exponential_decay_length_penalty[0], int) or self.exponential_decay_length_penalty[0] < 0:
+            raise ValueError(f"exponential_decay_length_penalty[0] must be a non-negative integer, got {self.exponential_decay_length_penalty[0]}.")
+        if not isinstance(self.exponential_decay_length_penalty[1], float) or self.exponential_decay_length_penalty[1] < 1.0:
+            raise ValueError(f"exponential_decay_length_penalty[1] must be a float >= 1.0, got {self.exponential_decay_length_penalty[1]}.")
         return
 
     def stop_sentences_to_token_ids(self, tokenizer):
@@ -77,6 +85,7 @@ def to_dict(self):
         ret["presence_penalty"] = self.presence_penalty
         ret["frequency_penalty"] = self.frequency_penalty
         ret["repetition_penalty"] = self.repetition_penalty
+        ret["exponential_decay_length_penalty"] = self.exponential_decay_length_penalty
         ret["temperature"] = self.temperature
         ret["top_p"] = self.top_p
         ret["top_k"] = self.top_k
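For completeness, a hedged usage sketch of the new parameter from the request side, assuming the remaining SamplingParams defaults already pass verify(); the concrete values are examples only.

from lightllm.server.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.8,
    top_p=0.9,
    # Start applying the penalty after 64 generated tokens, then scale the EOS
    # boost by 1.02 per additional token.
    exponential_decay_length_penalty=(64, 1.02),
)
params.verify()  # raises ValueError for malformed tuples, e.g. (64, 0.9) or (-1, 1.02)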