Skip to content

Commit fe10cf2

Browse files
committed
Removed some Obsolete APIs:
- `State.ToByteArray` and `State.FromByteArray` were never safe to use, now removed. - Added `State.Load` and `State.Save`, to replace the only uses of `ToByteArray`/`FromByteArray` - Removed `LLamaTokenType` - it's been removed from llama.cpp - Using `All` instead of `Instance` in `NativeLogConfig` - Removed redundant `unsafe` block in `SafeLLamaContextHandle`
1 parent e254ff3 commit fe10cf2

File tree

5 files changed

+67
-76
lines changed

5 files changed

+67
-76
lines changed

LLama/ChatSession.cs

Lines changed: 8 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -682,11 +682,9 @@ public void Save(string path)
682682
Directory.CreateDirectory(path);
683683

684684
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
685-
var bytes = ContextState?.ToByteArray();
686-
if (bytes is not null)
687-
{
688-
File.WriteAllBytes(modelStateFilePath, bytes);
689-
}
685+
if (ContextState != null)
686+
using (var stateStream = File.Create(modelStateFilePath))
687+
ContextState?.Save(stateStream);
690688

691689
string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
692690
File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));
@@ -722,10 +720,11 @@ public static SessionState Load(string path)
722720
throw new ArgumentException("Directory does not exist", nameof(path));
723721
}
724722

725-
string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
726-
var contextState = File.Exists(modelStateFilePath) ?
727-
State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
728-
: null;
723+
var modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
724+
State? contextState = default;
725+
if (File.Exists(modelStateFilePath))
726+
using (var modelStateStream = File.OpenRead(modelStateFilePath))
727+
contextState = State.Load(modelStateStream);
729728

730729
string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
731730
var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));

LLama/LLamaContext.cs

Lines changed: 56 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,7 @@
1313
using LLama.Sampling;
1414
using Microsoft.Extensions.Logging;
1515
using System.Threading;
16+
using System.Security.Cryptography;
1617

1718
namespace LLama
1819
{
@@ -622,28 +623,70 @@ protected override bool ReleaseHandle()
622623
}
623624

624625
/// <summary>
625-
/// Convert this state to a byte array
626+
/// Write all the bytes of this state to the given stream
626627
/// </summary>
628+
/// <param name="stream"></param>
629+
public async Task SaveAsync(Stream stream)
630+
{
631+
UnmanagedMemoryStream from;
632+
unsafe
633+
{
634+
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
635+
}
636+
await from.CopyToAsync(stream);
637+
}
638+
639+
/// <summary>
640+
/// Write all the bytes of this state to the given stream
641+
/// </summary>
642+
/// <param name="stream"></param>
643+
public void Save(Stream stream)
644+
{
645+
UnmanagedMemoryStream from;
646+
unsafe
647+
{
648+
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
649+
}
650+
from.CopyTo(stream);
651+
}
652+
653+
/// <summary>
654+
/// Load a state from a stream
655+
/// </summary>
656+
/// <param name="stream"></param>
627657
/// <returns></returns>
628-
[Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
629-
public byte[] ToByteArray()
658+
public static async Task<State> LoadAsync(Stream stream)
630659
{
631-
var bytes = new byte[_size];
632-
Marshal.Copy(handle, bytes, 0, (int)_size);
633-
return bytes;
660+
var memory = Marshal.AllocHGlobal((nint)stream.Length);
661+
var state = new State(memory, checked((ulong)stream.Length));
662+
663+
UnmanagedMemoryStream dest;
664+
unsafe
665+
{
666+
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
667+
}
668+
await stream.CopyToAsync(dest);
669+
670+
return state;
634671
}
635672

636673
/// <summary>
637-
/// Load state from a byte array
674+
/// Load a state from a stream
638675
/// </summary>
639-
/// <param name="bytes"></param>
676+
/// <param name="stream"></param>
640677
/// <returns></returns>
641-
[Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
642-
public static State FromByteArray(byte[] bytes)
678+
public static State Load(Stream stream)
643679
{
644-
var memory = Marshal.AllocHGlobal(bytes.Length);
645-
Marshal.Copy(bytes, 0, memory, bytes.Length);
646-
return new State(memory, (ulong)bytes.Length);
680+
var memory = Marshal.AllocHGlobal((nint)stream.Length);
681+
var state = new State(memory, checked((ulong)stream.Length));
682+
683+
unsafe
684+
{
685+
var dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
686+
stream.CopyTo(dest);
687+
}
688+
689+
return state;
647690
}
648691
}
649692

LLama/Native/LLamaTokenType.cs

Lines changed: 0 additions & 37 deletions
This file was deleted.

LLama/Native/NativeLogConfig.cs

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -1,4 +1,4 @@
1-
using System.Runtime.InteropServices;
1+
using System.Runtime.InteropServices;
22
using System.Text;
33
using System.Threading;
44
using Microsoft.Extensions.Logging;
@@ -52,7 +52,7 @@ public static void llama_log_set(LLamaLogCallback? logCallback)
5252
{
5353
// We can't set the log method yet since that would cause the llama.dll to load.
5454
// Instead configure it to be set when the native library loading is done
55-
NativeLibraryConfig.Instance.WithLogCallback(logCallback);
55+
NativeLibraryConfig.All.WithLogCallback(logCallback);
5656
}
5757
}
5858

LLama/Native/SafeLLamaContextHandle.cs

Lines changed: 1 addition & 15 deletions
Original file line number · Diff line number · Diff line change
@@ -536,10 +536,7 @@ public unsafe ulong GetState(byte* dest, ulong size)
536536
if (size < required)
537537
throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");
538538

539-
unsafe
540-
{
541-
return llama_state_get_data(this, dest);
542-
}
539+
return llama_state_get_data(this, dest);
543540
}
544541

545542
/// <summary>
@@ -589,17 +586,6 @@ public void SetSeed(uint seed)
589586
llama_set_rng_seed(this, seed);
590587
}
591588

592-
/// <summary>
593-
/// Set the number of threads used for decoding
594-
/// </summary>
595-
/// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
596-
/// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
597-
[Obsolete("Use `GenerationThreads` and `BatchThreads` properties")]
598-
public void SetThreads(uint threads, uint threadsBatch)
599-
{
600-
llama_set_n_threads(this, threads, threadsBatch);
601-
}
602-
603589
#region timing
604590
/// <summary>
605591
/// Get performance information

0 commit comments

Comments (0)