Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Retrieval Augmented Generation (RAG) in LLMUnity #246

Open
wants to merge 25 commits into
base: release/v2.3.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1b09078
adapt Search and SearchMethods of RAGSearchUnity
amakropoulos Sep 17, 2024
36008cc
add usearch as thirdparty
amakropoulos Sep 17, 2024
44d45ff
adapt KnowledgeBaseGame sample
amakropoulos Sep 17, 2024
7ed6419
remove setup text
amakropoulos Sep 19, 2024
b840015
check if gguf is only for embeddings
amakropoulos Sep 20, 2024
6d3834a
set embeddings only if the model only supports that
amakropoulos Sep 20, 2024
ae6d852
show error on completion if the model only supports embeddings
amakropoulos Sep 20, 2024
69aa6cd
reduce size of trash icon
amakropoulos Sep 20, 2024
e39ba0d
add magnifier and speech bubble icon
amakropoulos Sep 20, 2024
79760b8
add embedding models, start models with emoji
amakropoulos Sep 20, 2024
9a792de
show images to replace emoji unicode
amakropoulos Sep 20, 2024
54e14b5
add embedding models to ThirdParty
amakropoulos Sep 20, 2024
0c536e5
read embedding length of model
amakropoulos Sep 20, 2024
ce19d4f
use LLM embedding dimensions
amakropoulos Sep 20, 2024
8444871
abstract search methods and add brute force
amakropoulos Sep 20, 2024
e99109e
adapt search to function name changes
amakropoulos Sep 20, 2024
07b2271
remove dlls from ThirdParty folder
amakropoulos Oct 15, 2024
596a374
add filtered search functionality
amakropoulos Oct 15, 2024
8c04013
separate LLMCaller functionality
amakropoulos Oct 15, 2024
1d1999e
reorganise search and text splitter
amakropoulos Oct 15, 2024
523ccb0
implement search methods and search plugins, add sentence splitting p…
amakropoulos Oct 18, 2024
d9b24dd
adapt save / load functions of search
amakropoulos Oct 22, 2024
4100c30
bump to v2.12.0
amakropoulos Oct 22, 2024
bcb01a3
move sentence splitter to different file
amakropoulos Oct 22, 2024
1a1ef4d
adapt editor for LLMCaller and Search
amakropoulos Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions Editor/LLMCharacterEditor.cs → Editor/LLMCallerEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,45 @@

namespace LLMUnity
{
[CustomEditor(typeof(LLMCharacter), true)]
public class LLMCharacterEditor : PropertyEditor
[CustomEditor(typeof(LLMCaller), true)]
public class LLMCallerEditor : PropertyEditor
{
protected override Type[] GetPropertyTypes()
public override void OnInspectorGUI()
{
return new Type[] { typeof(LLMCharacter) };
LLMCaller llmScript = (LLMCaller)target;
SerializedObject llmScriptSO = new SerializedObject(llmScript);

OnInspectorGUIStart(llmScriptSO);
AddOptionsToggles(llmScriptSO);

AddSetupSettings(llmScriptSO);
AddChatSettings(llmScriptSO);
AddModelSettings(llmScriptSO);

OnInspectorGUIEnd(llmScriptSO);
}
}

public void AddModelSettings(SerializedObject llmScriptSO, LLMCharacter llmCharacterScript)
{
EditorGUILayout.LabelField("Model Settings", EditorStyles.boldLabel);
ShowPropertiesOfClass("", llmScriptSO, new List<Type> { typeof(ModelAttribute) }, false);
[CustomEditor(typeof(SearchMethod), true)]
public class SearchMethodEditor : LLMCallerEditor
{
public override void AddChatSettings(SerializedObject llmScriptSO) {}
}

if (llmScriptSO.FindProperty("advancedOptions").boolValue)
[CustomEditor(typeof(LLMCharacter), true)]
public class LLMCharacterEditor : LLMCallerEditor
{
public override void AddModelSettings(SerializedObject llmScriptSO)
{
if (!llmScriptSO.FindProperty("advancedOptions").boolValue)
{
base.AddModelSettings(llmScriptSO);
}
else
{
EditorGUILayout.LabelField("Model Settings", EditorStyles.boldLabel);
ShowPropertiesOfClass("", llmScriptSO, new List<Type> { typeof(ModelAttribute) }, false);

EditorGUILayout.BeginHorizontal();
GUILayout.Label("Grammar", GUILayout.Width(EditorGUIUtility.labelWidth));
if (GUILayout.Button("Load grammar", GUILayout.Width(buttonWidth)))
Expand All @@ -29,7 +53,7 @@ public void AddModelSettings(SerializedObject llmScriptSO, LLMCharacter llmChara
string path = EditorUtility.OpenFilePanelWithFilters("Select a gbnf grammar file", "", new string[] { "Grammar Files", "gbnf" });
if (!string.IsNullOrEmpty(path))
{
llmCharacterScript.SetGrammar(path);
((LLMCharacter)target).SetGrammar(path);
}
};
}
Expand All @@ -38,22 +62,6 @@ public void AddModelSettings(SerializedObject llmScriptSO, LLMCharacter llmChara
ShowPropertiesOfClass("", llmScriptSO, new List<Type> { typeof(ModelAdvancedAttribute) }, false);
}
}

public override void OnInspectorGUI()
{
LLMCharacter llmScript = (LLMCharacter)target;
SerializedObject llmScriptSO = new SerializedObject(llmScript);

OnInspectorGUIStart(llmScriptSO);
AddOptionsToggles(llmScriptSO);

AddSetupSettings(llmScriptSO);
AddChatSettings(llmScriptSO);
Space();
AddModelSettings(llmScriptSO, llmScript);

OnInspectorGUIEnd(llmScriptSO);
}
}

[CustomEditor(typeof(LLMClient))]
Expand Down
File renamed without changes.
57 changes: 46 additions & 11 deletions Editor/LLMEditor.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Threading.Tasks;
using UnityEditor;
Expand All @@ -18,7 +19,8 @@ public class LLMEditor : PropertyEditor
static float includeInBuildColumnWidth = 30f;
static float actionColumnWidth = 20f;
static int elementPadding = 10;
static GUIContent trashIcon;
static GUIContent trashGUIContent;
static Dictionary<string, Texture2D> icons = new Dictionary<string, Texture2D>();
static List<string> modelOptions;
static List<string> modelLicenses;
static List<string> modelURLs;
Expand All @@ -29,11 +31,6 @@ public class LLMEditor : PropertyEditor
bool customURLFocus = false;
bool expandedView = false;

protected override Type[] GetPropertyTypes()
{
return new Type[] { typeof(LLM) };
}

public void AddSecuritySettings(SerializedObject llmScriptSO, LLM llmScript)
{
void AddSSLLoad(string type, Callback<string> setterCallback)
Expand All @@ -55,7 +52,7 @@ void AddSSLInfo(string propertyName, string type, Callback<string> setterCallbac
{
EditorGUILayout.BeginHorizontal();
EditorGUILayout.LabelField("SSL " + type + " path", path);
if (GUILayout.Button(trashIcon, GUILayout.Height(actionColumnWidth), GUILayout.Width(actionColumnWidth))) setterCallback("");
if (GUILayout.Button(trashGUIContent, GUILayout.Height(actionColumnWidth), GUILayout.Width(actionColumnWidth))) setterCallback("");
EditorGUILayout.EndHorizontal();
}
}
Expand Down Expand Up @@ -109,7 +106,7 @@ public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript)
if (downloadOnStart != LLMManager.downloadOnStart) LLMManager.SetDownloadOnStart(downloadOnStart);
}

public void AddModelSettings(SerializedObject llmScriptSO)
public override void AddModelSettings(SerializedObject llmScriptSO)
{
List<Type> attributeClasses = new List<Type> { typeof(ModelAttribute) };
if (llmScriptSO.FindProperty("advancedOptions").boolValue)
Expand Down Expand Up @@ -298,11 +295,20 @@ async Task AddLoadButtons()
else await createButtons();
}

void LoadIcons()
{
if (icons.Count > 0) return;
icons["🗑️"] = Resources.Load<Texture2D>("llmunity_trash_icon");
icons["💬"] = Resources.Load<Texture2D>("llmunity_speechballoon_icon");
icons["🔍"] = Resources.Load<Texture2D>("llmunity_magnifier_icon");
}

void OnEnable()
{
LLM llmScript = (LLM)target;
LoadIcons();
ResetModelOptions();
trashIcon = new GUIContent(Resources.Load<Texture2D>("llmunity_trash_icon"), "Delete Model");
trashGUIContent = new GUIContent(icons["🗑️"], "Delete Model");
Texture2D loraLineTexture = new Texture2D(1, 1);
loraLineTexture.SetPixel(0, 0, Color.black);
loraLineTexture.Apply();
Expand Down Expand Up @@ -387,7 +393,7 @@ void OnEnable()
UpdateModels();
}

if (GUI.Button(actionRect, trashIcon))
if (GUI.Button(actionRect, trashGUIContent))
{
if (isSelected)
{
Expand Down Expand Up @@ -426,7 +432,20 @@ void OnEnable()
private void DrawCopyableLabel(Rect rect, string label, string text = "")
{
if (text == "") text = label;
EditorGUI.LabelField(rect, label);
string labelToShow = label;
foreach (var icon in icons)
{
if (StringInfo.GetNextTextElement(label) == icon.Key)
{
float iconSize = rect.height * 3 / 4;
GUI.DrawTexture(new Rect(rect.x, rect.y + (rect.height - iconSize) / 2, iconSize, iconSize), icon.Value);
rect.x += iconSize;
rect.width -= iconSize;
labelToShow = label.Substring(icon.Key.Length);
break;
}
}
EditorGUI.LabelField(rect, labelToShow);
if (Event.current.type == EventType.ContextClick && rect.Contains(Event.current.mousePosition))
{
GenericMenu menu = new GenericMenu();
Expand All @@ -443,6 +462,22 @@ private void CopyToClipboard(string text)
te.Copy();
}

public void AddExtrasToggle()
{
if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib);
}

public override void AddOptionsToggles(SerializedObject llmScriptSO)
{
AddDebugModeToggle();

EditorGUILayout.BeginHorizontal();
AddAdvancedOptionsToggle(llmScriptSO);
AddExtrasToggle();
EditorGUILayout.EndHorizontal();
Space();
}

public override void OnInspectorGUI()
{
if (elementFocus != "")
Expand Down
57 changes: 42 additions & 15 deletions Editor/PropertyEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,28 @@
using UnityEngine;
using System.Reflection;
using System.Collections.Generic;
using NUnit.Framework.Internal;

namespace LLMUnity
{
public class PropertyEditor : Editor
{
public static int buttonWidth = 150;

public void AddScript(SerializedObject llmScriptSO)
public virtual void AddScript(SerializedObject llmScriptSO)
{
var scriptProp = llmScriptSO.FindProperty("m_Script");
EditorGUILayout.PropertyField(scriptProp);
}

public bool ToggleButton(string text, bool activated)
public virtual bool ToggleButton(string text, bool activated)
{
GUIStyle style = new GUIStyle("Button");
if (activated) style.normal = new GUIStyleState() { background = Texture2D.grayTexture };
return GUILayout.Button(text, style, GUILayout.Width(buttonWidth));
}

public void AddSetupSettings(SerializedObject llmScriptSO)
public virtual void AddSetupSettings(SerializedObject llmScriptSO)
{
List<Type> attributeClasses = new List<Type>(){typeof(LocalRemoteAttribute)};
attributeClasses.Add(llmScriptSO.FindProperty("remote").boolValue ? typeof(RemoteAttribute) : typeof(LocalAttribute));
Expand All @@ -35,37 +36,62 @@ public void AddSetupSettings(SerializedObject llmScriptSO)
ShowPropertiesOfClass("Setup Settings", llmScriptSO, attributeClasses, true);
}

public void AddChatSettings(SerializedObject llmScriptSO)
public virtual void AddModelSettings(SerializedObject llmScriptSO)
{
List<Type> attributeClasses = new List<Type>(){typeof(ModelAttribute)};
if (llmScriptSO.FindProperty("advancedOptions").boolValue)
{
attributeClasses.Add(typeof(ModelAdvancedAttribute));
}
ShowPropertiesOfClass("Model Settings", llmScriptSO, attributeClasses, true);
}

public virtual void AddChatSettings(SerializedObject llmScriptSO)
{
List<Type> attributeClasses = new List<Type>(){typeof(ChatAttribute)};
if (llmScriptSO.FindProperty("advancedOptions").boolValue)
{
attributeClasses.Add(typeof(ChatAdvancedAttribute));
}
ShowPropertiesOfClass("Chat Settings", llmScriptSO, attributeClasses, false);
ShowPropertiesOfClass("Chat Settings", llmScriptSO, attributeClasses, true);
}

public void AddOptionsToggles(SerializedObject llmScriptSO)
public void AddDebugModeToggle()
{
LLMUnitySetup.SetDebugMode((LLMUnitySetup.DebugModeType)EditorGUILayout.EnumPopup("Log Level", LLMUnitySetup.DebugMode));
}

EditorGUILayout.BeginHorizontal();
public void AddAdvancedOptionsToggle(SerializedObject llmScriptSO)
{
SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty("advancedOptions");
string toggleText = (advancedOptionsProp.boolValue ? "Hide" : "Show") + " Advanced Options";
if (ToggleButton(toggleText, advancedOptionsProp.boolValue)) advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue;
if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib);
EditorGUILayout.EndHorizontal();
}

public virtual void AddOptionsToggles(SerializedObject llmScriptSO)
{
AddDebugModeToggle();
AddAdvancedOptionsToggle(llmScriptSO);
Space();
}

public void Space()
public virtual void Space()
{
EditorGUILayout.Space((int)EditorGUIUtility.singleLineHeight / 2);
}

protected virtual Type[] GetPropertyTypes()
{
return new Type[] {};
List<Type> types = new List<Type>();
Type currentType = target.GetType();
while (currentType != null)
{
types.Add(currentType);
currentType = currentType.BaseType;
if (currentType == typeof(MonoBehaviour)) break;
}
types.Reverse();
return types.ToArray();
}

public List<SerializedProperty> GetPropertiesOfClass(SerializedObject so, List<Type> attributeClasses)
Expand All @@ -92,7 +118,7 @@ public List<SerializedProperty> GetPropertiesOfClass(SerializedObject so, List<T
return properties;
}

public void ShowPropertiesOfClass(string title, SerializedObject so, List<Type> attributeClasses, bool addSpace = true, List<Type> excludeAttributeClasses = null)
public bool ShowPropertiesOfClass(string title, SerializedObject so, List<Type> attributeClasses, bool addSpace = true, List<Type> excludeAttributeClasses = null)
{
// display a property if it belongs to a certain class and/or has a specific attribute class
List<SerializedProperty> properties = GetPropertiesOfClass(so, attributeClasses);
Expand All @@ -109,7 +135,7 @@ public void ShowPropertiesOfClass(string title, SerializedObject so, List<Type>
}
foreach (SerializedProperty prop in removeProperties) properties.Remove(prop);
}
if (properties.Count == 0) return;
if (properties.Count == 0) return false;
if (title != "") EditorGUILayout.LabelField(title, EditorStyles.boldLabel);
foreach (SerializedProperty prop in properties)
{
Expand All @@ -129,6 +155,7 @@ public void ShowPropertiesOfClass(string title, SerializedObject so, List<Type>
}
}
if (addSpace) Space();
return true;
}

public bool PropertyInClass(SerializedProperty prop, Type targetClass, Type attributeClass = null)
Expand Down Expand Up @@ -162,7 +189,7 @@ public Attribute GetPropertyAttribute(SerializedProperty prop, Type attributeCla
return null;
}

public void OnInspectorGUIStart(SerializedObject scriptSO)
public virtual void OnInspectorGUIStart(SerializedObject scriptSO)
{
scriptSO.Update();
GUI.enabled = false;
Expand All @@ -171,7 +198,7 @@ public void OnInspectorGUIStart(SerializedObject scriptSO)
EditorGUI.BeginChangeCheck();
}

public void OnInspectorGUIEnd(SerializedObject scriptSO)
public virtual void OnInspectorGUIEnd(SerializedObject scriptSO)
{
if (EditorGUI.EndChangeCheck())
Repaint();
Expand Down
Binary file added Resources/llmunity_magnifier_icon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading