Commit 9abc856

Merge pull request SciSharp#750 from martindevans/minor_batched_example_improvements
Minor BatchedExecutor Example Improvements
2 parents ad6f22c + 6848728 commit 9abc856

File tree

3 files changed: +83 -77 lines

LLama.Examples/Examples/BatchedExecutorFork.cs
LLama.Examples/Examples/BatchedExecutorGuidance.cs
LLama.Examples/Examples/BatchedExecutorRewind.cs

LLama.Examples/Examples/BatchedExecutorFork.cs

Lines changed: 21 additions & 16 deletions
@@ -11,14 +11,20 @@ namespace LLama.Examples.Examples;
 /// </summary>
 public class BatchedExecutorFork
 {
-    private const int n_split = 16;
-    private const int n_len = 72;
+    /// <summary>
+    /// Set how many tokens to generate before forking
+    /// </summary>
+    private const int ForkTokenCount = 16;
+
+    /// <summary>
+    /// Set total length of the sequence to generate
+    /// </summary>
+    private const int TokenCount = 72;

     public static async Task Run()
     {
-        string modelPath = UserSettings.GetModelPath();
-
-        var parameters = new ModelParams(modelPath);
+        // Load model weights
+        var parameters = new ModelParams(UserSettings.GetModelPath());
         using var model = await LLamaWeights.LoadFromFileAsync(parameters);

         var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
@@ -27,7 +33,7 @@ public static async Task Run()
         using var executor = new BatchedExecutor(model, parameters);

         // Print some info
-        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+        var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");

         // Evaluate the initial prompt to create one conversation
@@ -42,17 +48,17 @@ await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
-                var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);
+                var reporter = progress.AddTask("Running Inference (1)", maxValue: TokenCount);

                // Run inference loop
-                for (var i = 0; i < n_len; i++)
+                for (var i = 0; i < TokenCount; i++)
                {
                    if (i != 0)
                        await executor.Infer();

                    // Occasionally fork all the active conversations
-                    if (i != 0 && i % n_split == 0)
-                        root.Split();
+                    if (i != 0 && i % ForkTokenCount == 0)
+                        root.Fork();

                    // Sample all active conversations
                    root.Sample();
@@ -79,8 +85,8 @@ await AnsiConsole
    private class Node
    {
        private readonly StreamingTokenDecoder _decoder;
-
-        private readonly DefaultSamplingPipeline _sampler;
+
+        private readonly DefaultSamplingPipeline _sampler = new();
        private Conversation? _conversation;

        private Node? _left;
@@ -90,7 +96,6 @@ private class Node

        public Node(Conversation conversation)
        {
-            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }
@@ -117,7 +122,7 @@ public void Sample()
            _conversation.Prompt(token);
        }

-        public void Split()
+        public void Fork()
        {
            if (_conversation != null)
            {
@@ -129,8 +134,8 @@ public void Split()
            }
            else
            {
-                _left?.Split();
-                _right?.Split();
+                _left?.Fork();
+                _right?.Fork();
            }
        }

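For readers skimming this diff, the fork pattern the example demonstrates boils down to the sketch below. It uses only calls visible in this commit (BatchedExecutor, executor.Create, Conversation.Prompt/Fork, executor.Infer); the model path and prompt are placeholder values.

using LLama;
using LLama.Batched;
using LLama.Common;

// Load weights and create a batched executor ("model.gguf" is a placeholder).
var parameters = new ModelParams("model.gguf");
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var executor = new BatchedExecutor(model, parameters);

// Seed one root conversation with a prompt and evaluate it.
using var root = executor.Create();
root.Prompt(executor.Context.Tokenize("Not many people know that"));
await executor.Infer();

// Fork: both conversations share the KV cache for everything evaluated so
// far and diverge from this point on.
using var branch = root.Fork();

Forking every ForkTokenCount tokens, as the example's Node tree does, doubles the number of live conversations at each split while reusing all earlier cache entries.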
LLama.Examples/Examples/BatchedExecutorGuidance.cs

Lines changed: 29 additions & 30 deletions
@@ -12,13 +12,15 @@ namespace LLama.Examples.Examples;
 /// </summary>
 public class BatchedExecutorGuidance
 {
-    private const int n_len = 32;
+    /// <summary>
+    /// Set how many tokens should be generated
+    /// </summary>
+    private const int TokenCount = 32;

     public static async Task Run()
     {
-        string modelPath = UserSettings.GetModelPath();
-
-        var parameters = new ModelParams(modelPath);
+        // Load model weights
+        var parameters = new ModelParams(UserSettings.GetModelPath());
         using var model = await LLamaWeights.LoadFromFileAsync(parameters);

         var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
@@ -29,7 +31,7 @@ public static async Task Run()
         using var executor = new BatchedExecutor(model, parameters);

         // Print some info
-        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+        var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
         Console.WriteLine($"Created executor with model: {name}");

         // Load the two prompts into two conversations
@@ -48,30 +50,30 @@ await AnsiConsole
        using var unguided = guided.Fork();

        // Run inference loop
-        var unguidedSampler = new GuidedSampler(null, weight);
+        var unguidedSampler = new DefaultSamplingPipeline();
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance, weight);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);
        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
-                var reporter = progress.AddTask("Running Inference", maxValue: n_len);
+                var reporter = progress.AddTask("Running Inference", maxValue: TokenCount);

-                for (var i = 0; i < n_len; i++)
+                for (var i = 0; i < TokenCount; i++)
                {
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
                    // guidance. This serves as a comparison to show the effect of guidance.
-                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
+                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), []);
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
                    // to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
-                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
+                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), []);
                    guidedDecoder.Add(g);

                    // Use this token to advance both guided _and_ guidance. Keeping them in sync (except for the initial prompt).
@@ -91,37 +93,34 @@ await AnsiConsole
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
    }

-    private class GuidedSampler(Conversation? guidance, float weight)
+    private class GuidedSampler(Conversation guidance, float weight)
        : BaseSamplingPipeline
    {
+        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
+        {
+            // Get the logits generated by the guidance sequences
+            var guidanceLogits = guidance.Sample();
+
+            // Modify these logits based on the guidance logits
+            candidates.Guidance(ctx, guidanceLogits, weight);
+
+            // Basic sampling
+            candidates.Temperature(ctx, 0.8f);
+            candidates.TopK(ctx, 25);
+            return candidates.SampleToken(ctx);
+        }
+
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
        }
-
+
        public override ISamplingPipeline Clone()
        {
            throw new NotSupportedException();
        }
-
+
        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
        }
-
-        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
-        {
-            if (guidance != null)
-            {
-                // Get the logits generated by the guidance sequences
-                var guidanceLogits = guidance.Sample();
-
-                // Modify these logits based on the guidance logits
-                candidates.Guidance(ctx, guidanceLogits, weight);
-            }
-
-            candidates.Temperature(ctx, 0.8f);
-            candidates.TopK(ctx, 25);
-
-            return candidates.SampleToken(ctx);
-        }
    }
}

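The guided/guidance pair above implements classifier-free guidance: the candidates.Guidance(ctx, guidanceLogits, weight) call blends the logits of the positive ("guided") sequence with those of the negative ("guidance") sequence before sampling. As a rough sketch, the element-wise blend below follows the common llama.cpp-style CFG definition; it is an assumption for illustration, not code quoted from LLamaSharp.

// Assumed CFG blend: move the guided logits away from the guidance
// (negative-prompt) logits, scaled by the guidance weight.
static void ApplyGuidance(Span<float> logits, ReadOnlySpan<float> guidanceLogits, float weight)
{
    for (var i = 0; i < logits.Length; i++)
        logits[i] = guidanceLogits[i] + weight * (logits[i] - guidanceLogits[i]);
}

With weight == 1 the blend is a no-op; weights above 1 amplify whatever distinguishes the positive prompt from the negative one.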
LLama.Examples/Examples/BatchedExecutorRewind.cs

Lines changed: 33 additions & 31 deletions
@@ -1,4 +1,4 @@
-using LLama.Batched;
+using LLama.Batched;
 using LLama.Common;
 using LLama.Native;
 using LLama.Sampling;
@@ -11,15 +11,25 @@ namespace LLama.Examples.Examples;
 /// </summary>
 public class BatchedExecutorRewind
 {
-    private const int n_generate = 24;
-    private const int n_rewind = 12;
-    private const int n_repeats = 6;
+    /// <summary>
+    /// Set how many tokens to generate before rewinding
+    /// </summary>
+    private const int TokensGenerate = 24;
+
+    /// <summary>
+    /// Set how many tokens to rewind
+    /// </summary>
+    private const int TokensRewind = 12;
+
+    /// <summary>
+    /// Set how many times to generate and rewind
+    /// </summary>
+    private const int RepeatCount = 6;

     public static async Task Run()
     {
-        string modelPath = UserSettings.GetModelPath();
-
-        var parameters = new ModelParams(modelPath);
+        // Load model weights
+        var parameters = new ModelParams(UserSettings.GetModelPath());
         using var model = await LLamaWeights.LoadFromFileAsync(parameters);

         var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
@@ -28,23 +38,23 @@ public static async Task Run()
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
-        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
+        var name = model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        using var conversation = executor.Create();
        conversation.Prompt(executor.Context.Tokenize(prompt));

        // Create the start node wrapping the conversation
-        var node = new Node(executor.Context);
+        var node = new Node();

        // Print the prompt
        Console.ForegroundColor = ConsoleColor.Green;
        Console.WriteLine(prompt);

-        for (var i = 0; i < n_repeats; i++)
+        for (var i = 0; i < RepeatCount; i++)
        {
-            for (var j = 0; j < n_generate; j++)
+            for (var j = 0; j < TokensGenerate; j++)
            {
                // Run inference
                await executor.Infer();
@@ -53,21 +63,21 @@ public static async Task Run()
                var token = node.Sample(conversation);

                // Continue conversation with this token
-                if (j != n_generate - 1)
+                if (j != TokensGenerate - 1)
                    conversation.Prompt(token);
            }

            // Write out what we generated
-            node.Write(n_rewind, i + 1);
+            node.Write(executor.Context, TokensRewind, i + 1);

            // Rewind back a few tokens
-            conversation.Rewind(n_rewind + 1);
+            conversation.Rewind(TokensRewind + 1);

            // Prompt with a token
-            conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));
+            conversation.Prompt(node.GetToken(TokensGenerate - TokensRewind - 1));

            // Create a new node around the rewound conversation
-            node = new Node(executor.Context);
+            node = new Node();
        }

        Console.WriteLine("Press any key to exit demo");
@@ -76,34 +86,26 @@ public static async Task Run()

    private class Node
    {
-        private readonly LLamaContext _context;
-
-        private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
-        private readonly DefaultSamplingPipeline Sampler;
-
-        public Node(LLamaContext context)
-        {
-            _context = context;
-            Sampler = new DefaultSamplingPipeline();
-        }
+        private readonly List<LLamaToken> _tokens = [ ];
+        private readonly DefaultSamplingPipeline _sampler = new();

        public LLamaToken Sample(Conversation conversation)
        {
-            var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
+            var token = _sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.Sample(), []);
            _tokens.Add(token);
            return token;
        }

-        public void Write(int n_rewind, int depth)
+        public void Write(LLamaContext context, int rewind, int depth)
        {
-            var decoder = new StreamingTokenDecoder(_context);
+            var decoder = new StreamingTokenDecoder(context);

-            for (var i = 0; i < _tokens.Count - n_rewind; i++)
+            for (var i = 0; i < _tokens.Count - rewind; i++)
                decoder.Add(_tokens[i]);

            AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]");

-            for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
+            for (var i = _tokens.Count - rewind; i < _tokens.Count; i++)
                decoder.Add(_tokens[i]);

            AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]");

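To make the generate/rewind cycle concrete, one iteration of the outer loop does roughly the following. This is a sketch built from the calls in this diff (Conversation.Sample/Prompt/Rewind and the renamed constants), with a bare DefaultSamplingPipeline standing in for the Node helper.

var sampler = new DefaultSamplingPipeline();
var tokens = new List<LLamaToken>();

// Generate TokensGenerate tokens, remembering each one.
for (var j = 0; j < TokensGenerate; j++)
{
    await executor.Infer();
    var token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), []);
    tokens.Add(token);

    // The final token is not prompted; the conversation is rewound instead.
    if (j != TokensGenerate - 1)
        conversation.Prompt(token);
}

// Drop the last TokensRewind tokens (plus one) from the conversation, then
// re-prompt with the token at the rewind point so generation continues from
// there along a fresh path.
conversation.Rewind(TokensRewind + 1);
conversation.Prompt(tokens[TokensGenerate - TokensRewind - 1]);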