Pulling the chat history into LLMChatHistory #244

Open · wants to merge 7 commits into release/v2.3.0

15 changes: 7 additions & 8 deletions README.md
@@ -151,19 +151,18 @@ It is also a good idea to enable the `Download on Build` option in the LLM GameO
<details>
<summary>Save / Load your chat history</summary>

To automatically save / load your chat history, you can specify the `Save` parameter of the LLMCharacter to the filename (or relative path) of your choice.
The file is saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).
This also saves the state of the LLM which means that the previously cached prompt does not need to be recomputed.
Your `LLMCharacter` components will automatically create corresponding `LLMChatHistory` components to store their chat histories.
- If you don't want to save the chat history, set the `EnableAutoSave` of the `LLMChatHistory` to false.
- You can specify the filename to use by setting the `ChatHistoryFilename` of the `LLMChatHistory`. The file is saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).

To manually save your chat history, you can use:
``` c#
llmCharacter.Save("filename");
llmChatHistory.Save();
```
and to load the history:
``` c#
llmCharacter.Load("filename");
llmChatHistory.Load();
```
where filename is the filename or relative path of your choice.
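
For example, a rough end-to-end sketch building on the calls above (this assumes `llmChatHistory` references the character's `LLMChatHistory` component and that `EnableAutoSave` / `ChatHistoryFilename` are plain settable members as described above; the filename is illustrative):
``` c#
// rough sketch: configure the history component from a script instead of the inspector
llmChatHistory.EnableAutoSave = false;             // opt out of automatic saving
llmChatHistory.ChatHistoryFilename = "npc_chat";   // illustrative name; saved in Application.persistentDataPath
llmChatHistory.Save();                             // persist the current messages manually
llmChatHistory.Load();                             // restore them later
```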

</details>
<details>
@@ -452,8 +451,8 @@ If the user's GPU is not supported, the LLM will fall back to the CPU
- `Port` port of the LLM server (if `Remote` is set)
- `Num Retries` number of HTTP request retries from the LLM server (if `Remote` is set)
- `API key` API key of the LLM server (if `Remote` is set)
- <details><summary><code>Save</code> save filename or relative path</summary> If set, the chat history and LLM state (if save cache is enabled) is automatically saved to file specified. <br> The chat history is saved with a json suffix and the LLM state with a cache suffix. <br> Both files are saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).</details>
- `Save Cache` select to save the LLM state along with the chat history. The LLM state is typically around 100MB+.
- <details><summary><code>Cache Filename</code> save filename or relative path</summary> If set, the LLM state (if save cache is enabled) is automatically saved to file specified. <br> The LLM state is saved with a cache suffix. <br> The file is saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).</details>
- `Save Cache` select to save the LLM state. The LLM state is typically around 100MB+.
- `Debug Prompt` select to log the constructed prompts in the Unity Editor

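As a rough script-side sketch only (assuming the `Cache Filename` and `Save Cache` settings above map to the public `save` and `saveCache` fields of `LLMCharacter`, as in this PR's diff; the filename is illustrative):
``` c#
// rough sketch: enable the prompt cache from code instead of the inspector
llmCharacter.save = "npc1";      // illustrative name; the cache is written as npc1.cache in persistentDataPath
llmCharacter.saveCache = true;   // also persist the LLM state (typically 100MB+)
```
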
#### 🗨️ Chat Settings
233 changes: 109 additions & 124 deletions Runtime/LLMCharacter.cs
@@ -32,9 +32,7 @@ public class LLMCharacter : MonoBehaviour
[Remote] public int numRetries = 10;
/// <summary> allows to use a server with API key </summary>
[Remote] public string APIKey;
/// <summary> file to save the chat history.
/// The file is saved only for Chat calls with addToHistory set to true.
/// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). </summary>
/// <summary> filename to use when saving the cache or chat history. </summary>
[LLM] public string save = "";
/// <summary> toggle to save the LLM cache. This speeds up the prompt calculation but also requires ~100MB of space per character. </summary>
[LLM] public bool saveCache = false;
@@ -112,6 +110,14 @@ public class LLMCharacter : MonoBehaviour
/// By providing a token ID and a positive or negative bias value, you can increase or decrease the probability of that token being generated. </summary>
public Dictionary<int, string> logitBias = null;

    /// <summary> the chat history component that this character uses to store its chat messages </summary>
public LLMChatHistory chatHistory {
get { return _chatHistory; }
set {
_chatHistory = value;
isCacheInvalid = true;
}
}
/// <summary> the name of the player </summary>
[Chat] public string playerName = "user";
/// <summary> the name of the AI </summary>
@@ -122,13 +128,14 @@
public bool setNKeepToPrompt = true;

/// \cond HIDE
public List<ChatMessage> chat;
private SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
[SerializeField, Chat]
private LLMChatHistory _chatHistory;
private string chatTemplate;
private ChatTemplate template = null;
public string grammarString;
private List<(string, string)> requestHeaders;
private List<UnityWebRequest> WIPRequests = new List<UnityWebRequest>();
private bool isCacheInvalid = false;
/// \endcond

/// <summary>
@@ -140,7 +147,7 @@ public class LLMCharacter : MonoBehaviour
/// - the chat template is constructed
/// - the number of tokens to keep are based on the system prompt (if setNKeepToPrompt=true)
/// </summary>
public void Awake()
public async void Awake()
{
// Start the LLM server in a cross-platform way
if (!enabled) return;
@@ -163,7 +170,8 @@ public void Awake()
}

InitGrammar();
InitHistory();
await InitHistory();
await LoadCache();
}

void OnValidate()
@@ -215,72 +223,35 @@ void SortBySceneAndHierarchy(LLM[] array)
}
}

protected void InitHistory()
protected async Task InitHistory()
{
InitPrompt();
_ = LoadHistory();
}

protected async Task LoadHistory()
{
if (save == "" || !File.Exists(GetJsonSavePath(save))) return;
await chatLock.WaitAsync(); // Acquire the lock
try
{
await Load(save);
}
finally
{
chatLock.Release(); // Release the lock
// If no specific chat history object has been assigned to this character, create one.
if (chatHistory == null) {
chatHistory = gameObject.AddComponent<LLMChatHistory>();
chatHistory.ChatHistoryFilename = save;
Collaborator: since LLMChatHistory automatically creates a new ChatHistoryFilename if it is empty, this will keep autosaving the history to a new save file.

Author: LLMChatHistory only creates a new filename on save so it should be fine I think.

await chatHistory.Load();
}
}

public virtual string GetSavePath(string filename)
public virtual string GetCacheSavePath()
{
return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/');
}

public virtual string GetJsonSavePath(string filename)
{
return GetSavePath(filename + ".json");
}

public virtual string GetCacheSavePath(string filename)
{
return GetSavePath(filename + ".cache");
}

private void InitPrompt(bool clearChat = true)
{
if (chat != null)
{
if (clearChat) chat.Clear();
}
else
{
chat = new List<ChatMessage>();
}
ChatMessage promptMessage = new ChatMessage { role = "system", content = prompt };
if (chat.Count == 0)
{
chat.Add(promptMessage);
}
else
{
chat[0] = promptMessage;
}
return Path.Combine(Application.persistentDataPath, save + ".cache").Replace('\\', '/');
}

/// <summary>
/// Set the system prompt for the LLMCharacter.
/// </summary>
/// <param name="newPrompt"> the system prompt </param>
/// <param name="clearChat"> whether to clear (true) or keep (false) the current chat history on top of the system prompt. </param>
public void SetPrompt(string newPrompt, bool clearChat = true)
public async Task SetPrompt(string newPrompt, bool clearChat = true)
{
prompt = newPrompt;
nKeep = -1;
InitPrompt(clearChat);

if (clearChat) {
// Clear any existing messages
await chatHistory?.Clear();
}
}

private bool CheckTemplate()
Expand All @@ -293,12 +264,16 @@ private bool CheckTemplate()
return true;
}

private ChatMessage GetSystemPromptMessage() {
return new ChatMessage() { role = LLMConstants.SYSTEM_ROLE, content = prompt };
}

private async Task<bool> InitNKeep()
{
if (setNKeepToPrompt && nKeep == -1)
{
if (!CheckTemplate()) return false;
string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, playerName, "", false);
string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){GetSystemPromptMessage()}, playerName, "", false);
List<int> tokens = await Tokenize(systemPrompt);
if (tokens == null) return false;
SetNKeep(tokens);
@@ -400,20 +375,19 @@ ChatRequest GenerateRequest(string prompt)
return chatRequest;
}

public void AddMessage(string role, string content)
public async Task AddMessage(string role, string content)
{
// add the question / answer to the chat list, update prompt
chat.Add(new ChatMessage { role = role, content = content });
await chatHistory.AddMessage(role, content);
}

public void AddPlayerMessage(string content)
public async Task AddPlayerMessage(string content)
{
AddMessage(playerName, content);
await AddMessage(playerName, content);
}

public void AddAIMessage(string content)
public async Task AddAIMessage(string content)
{
AddMessage(AIName, content);
await AddMessage(AIName, content);
}

protected string ChatContent(ChatResult result)
@@ -490,44 +464,39 @@ protected string SlotContent(SlotResult result)
/// <returns>the LLM response</returns>
public async Task<string> Chat(string query, Callback<string> callback = null, EmptyCallback completionCallback = null, bool addToHistory = true)
{
// handle a chat message by the user
// call the callback function while the answer is received
// call the completionCallback function when the answer is fully received
await LoadTemplate();
if (!CheckTemplate()) return null;
if (!await InitNKeep()) return null;

var playerMessage = new ChatMessage() { role = playerName, content = query };

string json;
await chatLock.WaitAsync();
try
{
AddPlayerMessage(query);
string prompt = template.ComputePrompt(chat, playerName, AIName);
json = JsonUtility.ToJson(GenerateRequest(prompt));
chat.RemoveAt(chat.Count - 1);
}
finally
{
chatLock.Release();
}
// Setup the full list of messages for the current request
List<ChatMessage> promptMessages = chatHistory ?
await chatHistory.GetChatMessages() :
new List<ChatMessage>();
promptMessages.Insert(0, GetSystemPromptMessage());
promptMessages.Add(playerMessage);

string result = await CompletionRequest(json, callback);
// Prepare the request
string formattedPrompt = template.ComputePrompt(promptMessages, playerName, AIName);
string requestJson = JsonUtility.ToJson(GenerateRequest(formattedPrompt));

// Call the LLM
string result = await CompletionRequest(requestJson, callback);

// Update our chat history if required
if (addToHistory && result != null)
{
await chatLock.WaitAsync();
try
{
AddPlayerMessage(query);
AddAIMessage(result);
}
finally
{
chatLock.Release();
}
if (save != "") _ = Save(save);
await _chatHistory.AddMessages(
new List<ChatMessage> {
new ChatMessage { role = playerName, content = query },
new ChatMessage { role = AIName, content = result }
}
);
}

await SaveCache();

completionCallback?.Invoke();
return result;
}
@@ -634,46 +603,29 @@ protected async Task<string> Slot(string filepath, string action)
}

/// <summary>
/// Saves the chat history and cache to the provided filename / relative path.
/// Saves the cache to the provided filename / relative path.
/// </summary>
/// <param name="filename">filename / relative path to save the chat history</param>
/// <param name="filename">filename / relative path to save the cache</param>
/// <returns></returns>
public virtual async Task<string> Save(string filename)
public virtual async Task<string> SaveCache()
{
string filepath = GetJsonSavePath(filename);
string dirname = Path.GetDirectoryName(filepath);
if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname);
string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) });
File.WriteAllText(filepath, json);

string cachepath = GetCacheSavePath(filename);
if (remote || !saveCache) return null;
string result = await Slot(cachepath, "save");
string result = await Slot(GetCacheSavePath(), "save");

// We now have a valid cache
isCacheInvalid = false;

return result;
}

/// <summary>
/// Load the chat history and cache from the provided filename / relative path.
/// Load the prompt cache.
/// </summary>
/// <param name="filename">filename / relative path to load the chat history from</param>
/// <returns></returns>
public virtual async Task<string> Load(string filename)
public virtual async Task<string> LoadCache()
{
string filepath = GetJsonSavePath(filename);
if (!File.Exists(filepath))
{
LLMUnitySetup.LogError($"File {filepath} does not exist.");
return null;
}
string json = File.ReadAllText(filepath);
List<ChatMessage> chatHistory = JsonUtility.FromJson<ChatListWrapper>(json).chat;
InitPrompt(true);
chat.AddRange(chatHistory);
LLMUnitySetup.Log($"Loaded {filepath}");

string cachepath = GetCacheSavePath(filename);
if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null;
string result = await Slot(cachepath, "restore");
if (remote || !saveCache || isCacheInvalid || !File.Exists(GetCacheSavePath())) return null;

string result = await Slot(GetCacheSavePath(), "restore");
return result;
}

@@ -843,6 +795,39 @@ protected async Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, Co
if (remote) return await PostRequestRemote(json, endpoint, getContent, callback);
return await PostRequestLocal(json, endpoint, getContent, callback);
}

#region Obsolete Functions

[Obsolete]
public virtual async Task<string> Save(string filename) {

if (chatHistory) {
await chatHistory.Save();
}

return await SaveCache();
}

[Obsolete]
public virtual async Task<string> Load(string filename) {

if (chatHistory) {
chatHistory.ChatHistoryFilename = filename;
await chatHistory.Load();
}

save = filename;
return await LoadCache();
}

[Obsolete]
public virtual string GetSavePath(string filename)
{
return _chatHistory.GetChatHistoryFilePath();
}

#endregion

}

/// \cond HIDE