Skip to content

Commit

Permalink
Merge pull request SciSharp#845 from martindevans/remove_obsolete
Browse files Browse the repository at this point in the history
Removed some Obsolete APIs
  • Loading branch information
martindevans authored Jul 12, 2024
2 parents c0efbf0 + fe10cf2 commit 1e63bb1
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 76 deletions.
17 changes: 8 additions & 9 deletions LLama/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,9 @@ public void Save(string path)
Directory.CreateDirectory(path);

string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
var bytes = ContextState?.ToByteArray();
if (bytes is not null)
{
File.WriteAllBytes(modelStateFilePath, bytes);
}
if (ContextState != null)
using (var stateStream = File.Create(modelStateFilePath))
ContextState?.Save(stateStream);

string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));
Expand Down Expand Up @@ -722,10 +720,11 @@ public static SessionState Load(string path)
throw new ArgumentException("Directory does not exist", nameof(path));
}

string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
var contextState = File.Exists(modelStateFilePath) ?
State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
: null;
var modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
State? contextState = default;
if (File.Exists(modelStateFilePath))
using (var modelStateStream = File.OpenRead(modelStateFilePath))
contextState = State.Load(modelStateStream);

string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));
Expand Down
69 changes: 56 additions & 13 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using LLama.Sampling;
using Microsoft.Extensions.Logging;
using System.Threading;
using System.Security.Cryptography;

namespace LLama
{
Expand Down Expand Up @@ -622,28 +623,70 @@ protected override bool ReleaseHandle()
}

/// <summary>
/// Convert this state to a byte array
/// Write all the bytes of this state to the given stream
/// </summary>
/// <param name="stream"></param>
/// <summary>
/// Write all the bytes of this state to the given stream.
/// </summary>
/// <param name="stream">Destination stream the raw state bytes are copied into.</param>
public async Task SaveAsync(Stream stream)
{
    UnmanagedMemoryStream from;
    unsafe
    {
        // Wrap the unmanaged state buffer (at `handle`, `Size` bytes) without copying it.
        // checked() guards against a Size too large to represent as a signed long.
        from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
    }

    // Dispose the wrapper stream when done; ConfigureAwait(false) because this is
    // library code with no need to resume on a captured synchronization context.
    using (from)
        await from.CopyToAsync(stream).ConfigureAwait(false);
}

/// <summary>
/// Write all the bytes of this state to the given stream
/// </summary>
/// <param name="stream"></param>
/// <summary>
/// Write all the bytes of this state to the given stream.
/// </summary>
/// <param name="stream">Destination stream the raw state bytes are copied into.</param>
public void Save(Stream stream)
{
    UnmanagedMemoryStream from;
    unsafe
    {
        // Wrap the unmanaged state buffer (at `handle`, `Size` bytes) without copying it.
        // checked() guards against a Size too large to represent as a signed long.
        from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
    }

    // Dispose the wrapper stream once the copy completes.
    using (from)
        from.CopyTo(stream);
}

/// <summary>
/// Load a state from a stream
/// </summary>
/// <param name="stream"></param>
/// <returns></returns>
[Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
public byte[] ToByteArray()
public static async Task<State> LoadAsync(Stream stream)
{
var bytes = new byte[_size];
Marshal.Copy(handle, bytes, 0, (int)_size);
return bytes;
var memory = Marshal.AllocHGlobal((nint)stream.Length);
var state = new State(memory, checked((ulong)stream.Length));

UnmanagedMemoryStream dest;
unsafe
{
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
}
await stream.CopyToAsync(dest);

return state;
}

/// <summary>
/// Load state from a byte array
/// Load a state from a stream
/// </summary>
/// <param name="bytes"></param>
/// <param name="stream"></param>
/// <returns></returns>
[Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
public static State FromByteArray(byte[] bytes)
public static State Load(Stream stream)
{
var memory = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, memory, bytes.Length);
return new State(memory, (ulong)bytes.Length);
var memory = Marshal.AllocHGlobal((nint)stream.Length);
var state = new State(memory, checked((ulong)stream.Length));

unsafe
{
var dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
stream.CopyTo(dest);
}

return state;
}
}

Expand Down
37 changes: 0 additions & 37 deletions LLama/Native/LLamaTokenType.cs

This file was deleted.

4 changes: 2 additions & 2 deletions LLama/Native/NativeLogConfig.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Runtime.InteropServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -52,7 +52,7 @@ public static void llama_log_set(LLamaLogCallback? logCallback)
{
// We can't set the log method yet since that would cause the llama.dll to load.
// Instead configure it to be set when the native library loading is done
NativeLibraryConfig.Instance.WithLogCallback(logCallback);
NativeLibraryConfig.All.WithLogCallback(logCallback);
}
}

Expand Down
16 changes: 1 addition & 15 deletions LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,7 @@ public unsafe ulong GetState(byte* dest, ulong size)
if (size < required)
throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");

unsafe
{
return llama_state_get_data(this, dest);
}
return llama_state_get_data(this, dest);
}

/// <summary>
Expand Down Expand Up @@ -589,17 +586,6 @@ public void SetSeed(uint seed)
llama_set_rng_seed(this, seed);
}

/// <summary>
/// Set the number of threads used for decoding
/// </summary>
/// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
/// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
[Obsolete("Use `GenerationThreads` and `BatchThreads` properties")]
public void SetThreads(uint threads, uint threadsBatch)
{
llama_set_n_threads(this, threads, threadsBatch);
}

#region timing
/// <summary>
/// Get performance information
Expand Down

0 comments on commit 1e63bb1

Please sign in to comment.