Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ class GenericLlmRequest
, mLoraConfig(std::move(loraConfig))
, mLookaheadConfig(std::move(lookaheadConfig))
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
, mContextChunkSize{mPromptLen}
, mContextChunkSizeTarget{mPromptLen}
, mContextChunkSizeDraft{mPromptLen}
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
Expand Down Expand Up @@ -256,7 +257,8 @@ class GenericLlmRequest
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
, mLookaheadConfig(lookaheadConfig)
, mContextChunkSize(mPromptLen)
, mContextChunkSizeTarget(mPromptLen)
, mContextChunkSizeDraft(mPromptLen)
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(std::make_shared<VecTokens>(draftTokens.value_or(VecTokens())))
Expand Down Expand Up @@ -293,7 +295,8 @@ class GenericLlmRequest
, mOrigPromptLen(mPromptLen)
, mNumPreDecodedTokens(mSamplingConfig.beamWidth, 0)
, mMaxSentTokenLen(mPromptLen)
, mContextChunkSize{mPromptLen}
, mContextChunkSizeTarget{mPromptLen}
, mContextChunkSizeDraft{mPromptLen}
, mLogProbs(mSamplingConfig.beamWidth)
, mCumLogProbs(mSamplingConfig.beamWidth)
, mDraftTokens(std::make_shared<VecTokens>())
Expand Down Expand Up @@ -861,7 +864,8 @@ class GenericLlmRequest
mContextCurrentPositionDraft = 0;
mPrepopulatedPromptLenTarget = 0;
mPrepopulatedPromptLenDraft = 0;
mContextChunkSize = mPromptLen;
mContextChunkSizeTarget = mPromptLen;
mContextChunkSizeDraft = mPromptLen;
mSeqSlot.reset();
}

Expand Down Expand Up @@ -1590,7 +1594,7 @@ class GenericLlmRequest
TLLM_CHECK_WITH_INFO(
isContextInitState() || isDisaggGenerationInitState() || isDisaggGenerationTransmissionComplete(),
"getContextChunkSize is only possible during the context phase or generation init phase.");
return mContextChunkSize;
return mUseDraftModel ? mContextChunkSizeDraft : mContextChunkSizeTarget;
}

/// To set the context chunk size, throw an exception when the chunk size is negative. If the chunk
Expand All @@ -1602,7 +1606,8 @@ class GenericLlmRequest
isContextInitState() || isDisaggGenerationInitState() || isDisaggGenerationTransmissionComplete(),
"setContextChunkSize is only possible during the context phase or generation init phase.");
TLLM_CHECK_WITH_INFO(size >= 0, "The chunk size of context (%d) can't be negative.", size);
mContextChunkSize = std::min(size, getContextRemainingLength());
auto& contextChunkSize = mUseDraftModel ? mContextChunkSizeDraft : mContextChunkSizeTarget;
contextChunkSize = std::min(size, getContextRemainingLength());
}

/// Determines whether the current position is only one chunk away from the end of the context.
Expand All @@ -1625,9 +1630,10 @@ class GenericLlmRequest
{
TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");

mContextCurrentPositionDraft += getContextChunkSize();
mContextCurrentPositionTarget += getContextChunkSize();
setContextChunkSize(0);
mContextCurrentPositionDraft += mContextChunkSizeDraft;
mContextCurrentPositionTarget += mContextChunkSizeTarget;
mContextChunkSizeDraft = 0;
mContextChunkSizeTarget = 0;
}

[[nodiscard]] executor::PriorityType priority() const noexcept
Expand Down Expand Up @@ -1987,7 +1993,8 @@ class GenericLlmRequest
// Paged-KV-Cache must be enabled while enabling Chunked-Context.
// The size of each context chunk, except the last one, must be a multiple of the KV-Cache block size.
// Value `0` means Chunked-Context is disabled.
SizeType32 mContextChunkSize{0};
SizeType32 mContextChunkSizeTarget{0};
SizeType32 mContextChunkSizeDraft{0};
SizeType32 mContextCurrentPositionTarget{0};
SizeType32 mContextCurrentPositionDraft{0};

Expand Down