
Refactoring for StatelessExecutor #1084

Open
aropb opened this issue Feb 6, 2025 · 2 comments

aropb commented Feb 6, 2025

Description

Hi,

I'd like to present the results of refactoring the StatelessExecutor code.

This looks strange and suboptimal (the context can be many gigabytes):

    Context = _weights.CreateContext(_params, logger);
    Context.Dispose();

    public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        Images = [];
        _weights = weights;
        _params = @params;
        _logger = logger;
        _batch = new LLamaBatch();

        Context = _weights.CreateContext(_params, logger);
        Context.Dispose();
    }
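For a sense of scale, here is a back-of-the-envelope sketch of what that throwaway context can cost. llama.cpp allocates the KV cache up front when a context is created; all the model dimensions below are illustrative assumptions (roughly a 7B-class model with a 32K context), not values taken from this issue:

    // Rough KV-cache size estimate; every constant here is an assumption.
    const long nLayers = 32;      // transformer layers
    const long nKvHeads = 32;     // key/value heads (no GQA assumed)
    const long headDim = 128;     // dimension per head
    const long nCtx = 32768;      // requested context length
    const long bytesPerElem = 2;  // fp16 cache entries

    // 2x for the separate K and V tensors
    long kvBytes = 2 * nLayers * nCtx * nKvHeads * headDim * bytesPerElem;
    Console.WriteLine($"{kvBytes / (1024.0 * 1024 * 1024):F1} GiB"); // ~16.0 GiB

So constructing and immediately disposing a context can allocate and free gigabytes of memory for nothing.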

I don't claim this is exactly right, but I use this code and it works:

// usings added so the snippet compiles standalone (LLamaSharp namespaces)
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using LLama;
using LLama.Abstractions;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
using LLama.Transformers;
using Microsoft.Extensions.Logging;

public class StatelessExecutor : ILLamaExecutor
{
    private readonly LLamaWeights _weights;
    private readonly IContextParams _params;
    private readonly ILogger? _logger;
    private readonly LLamaBatch _batch;

    // LLava Section
    public bool IsMultiModal => false;

    /// <inheritdoc />
    public LLavaWeights? ClipModel => default;

    /// <inheritdoc />
    public List<byte[]> Images { get; }

    /// <summary>
    /// The context used by the executor when running the inference.
    /// </summary>
    public LLamaContext? Context { get; private set; }

    /// <summary>
    /// If true, applies the default template to the prompt, following the rules for <a href="https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template">llama_chat_apply_template</a>.
    /// </summary>
    public bool ApplyTemplate { get; init; }
    
    /// <summary>
    /// The system message to use with the prompt. Only used when <see cref="ApplyTemplate" /> is true.
    /// </summary>
    public string? SystemMessage { get; init; }

    /// <summary>
    /// Create a new stateless executor which will use the given model
    /// </summary>
    /// <param name="weights"></param>
    /// <param name="params"></param>
    /// <param name="logger"></param>
    public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        Images = [];
        _weights = weights;
        _params = @params;
        _logger = logger;
        _batch = new LLamaBatch();
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken token = default)
    {
        try
        {
            // Create an inference context which will be disposed when this method exits;
            // it is also exposed through the Context property while inference runs
            // (a lowercase local avoids shadowing the Context property)
            Context = _weights.CreateContext(_params, _logger);
            using var context = Context;

            // Reset the sampling pipeline (if there is one)
            inferenceParams?.SamplingPipeline?.Reset();

            // Sanity check inference params
            inferenceParams ??= new InferenceParams();
            if (inferenceParams.TokensKeep > Context.ContextSize)
                throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");

            // Create decoders for the token stream
            var decoder = new StreamingTokenDecoder(Context);
            var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);

            if (ApplyTemplate)
            {
                var template = new LLamaTemplate(_weights.NativeHandle) { AddAssistant = true };
                if (SystemMessage != null) template.Add("system", SystemMessage);

                template.Add("user", text);
                text = PromptTemplateTransformer.ToModelPrompt(template);
            }
            
            // Tokenize the prompt
            var tokens = Context.Tokenize(text: text, addBos: true, special: true).ToList();

            // Evaluate the prompt, in chunks smaller than the max batch size
            var n_past = 0;
            var (r, _, past) = await Context.DecodeAsync(tokens, LLamaSeqId.Zero, _batch, n_past);
            n_past = past;

            if (r != DecodeResult.Ok)
                throw new LLamaDecodeError(r);

            // Begin loop, evaluating one token at a time
            var maxTokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
            for(var i = 0; i < maxTokens && !token.IsCancellationRequested; i++)
            {
                // Sample with the pipeline
                var id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, _batch.TokenCount - 1);

                // Check if this token should end generation
                if (id.IsEndOfGeneration(_weights.Vocab))
                    break;

                // Decode this token into text
                decoder.Add(id);
                var decoded = decoder.Read();
                yield return decoded;

                // Check if any of the antiprompts have been generated
                if (antiprocessor.Add(decoded))
                    break;

                tokens.Clear();
                tokens.Add(id);

                // when run out of context
                // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
                if (n_past + tokens.Count >= Context.ContextSize)
                {
                    var canAddBos = Context.Vocab.ShouldAddBOS;
                    var tokensKeep = inferenceParams.TokensKeep;

                    // number of tokens to keep when resetting context
                    // Ported from https://github.com/ggerganov/llama.cpp/blob/60325fa56f61c228464c9f065db3aa6a61f2156e/examples/main/main.cpp#L334
                    if (tokensKeep < 0 || tokensKeep > tokens.Count)
                    {
                        tokensKeep = tokens.Count;
                    }
                    else
                    {
                        tokensKeep += Convert.ToInt32(canAddBos);
                    }

                    var n_left = n_past - tokensKeep;
                    var n_discard = n_left / 2;

                    NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, LLamaSeqId.Zero, tokensKeep, tokensKeep + n_discard);
                    NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, LLamaSeqId.Zero, tokensKeep + n_discard, n_past, -n_discard);

                    n_past -= n_discard;
                }

                // Evaluate with this new token
                _batch.Clear();
                _batch.Add(id, n_past++, LLamaSeqId.Zero, true);

                var returnCode = await Context.DecodeAsync(_batch, token);
                if (returnCode != 0)
                    throw new LLamaDecodeError(returnCode);
            }
        }
        finally
        {
            Context = null;
        }
    }
}
sangyuxiaowu (Contributor) commented

Yes, it is indeed possible.
#858 (comment)

aropb (Author) commented Feb 6, 2025

#858 (comment)

I always use executor like this:

    public IAsyncEnumerable<GeneratedTextContent> GenerateTextAsync(string prompt, TextGenerationOptions options, CancellationToken cancellationToken = default)
    {
        StatelessExecutor executor = new(_weights, _modelParams);
        return executor.InferAsync(prompt, _inferenceParams, cancellationToken).Select(x => new GeneratedTextContent(x));
    }
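(As an aside, the Select over IAsyncEnumerable<string> here assumes an async LINQ implementation such as the System.Linq.Async package; LINQ to Objects alone does not provide it.)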

I don't see the point of keeping a reference to an already-disposed context around after the constructor runs.
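To illustrate the problem (a minimal sketch; the model path and the exact exception behaviour are my assumptions, not tested output):

    var parameters = new ModelParams("model.gguf"); // hypothetical model file
    using var weights = LLamaWeights.LoadFromFile(parameters);

    var executor = new StatelessExecutor(weights, parameters);

    // The constructor created and immediately disposed this context, so the
    // property still points at a dead object; using it would be expected to
    // throw (e.g. an ObjectDisposedException from the freed native handle).
    var tokens = executor.Context.Tokenize("Hello", addBos: true, special: true);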
