diff --git a/Editor/LLMCharacterEditor.cs b/Editor/LLMCallerEditor.cs similarity index 60% rename from Editor/LLMCharacterEditor.cs rename to Editor/LLMCallerEditor.cs index 0111dcb4..b63913af 100644 --- a/Editor/LLMCharacterEditor.cs +++ b/Editor/LLMCallerEditor.cs @@ -5,21 +5,45 @@ namespace LLMUnity { - [CustomEditor(typeof(LLMCharacter), true)] - public class LLMCharacterEditor : PropertyEditor + [CustomEditor(typeof(LLMCaller), true)] + public class LLMCallerEditor : PropertyEditor { - protected override Type[] GetPropertyTypes() + public override void OnInspectorGUI() { - return new Type[] { typeof(LLMCharacter) }; + LLMCaller llmScript = (LLMCaller)target; + SerializedObject llmScriptSO = new SerializedObject(llmScript); + + OnInspectorGUIStart(llmScriptSO); + AddOptionsToggles(llmScriptSO); + + AddSetupSettings(llmScriptSO); + AddChatSettings(llmScriptSO); + AddModelSettings(llmScriptSO); + + OnInspectorGUIEnd(llmScriptSO); } + } - public void AddModelSettings(SerializedObject llmScriptSO, LLMCharacter llmCharacterScript) - { - EditorGUILayout.LabelField("Model Settings", EditorStyles.boldLabel); - ShowPropertiesOfClass("", llmScriptSO, new List<Type> { typeof(ModelAttribute) }, false); + [CustomEditor(typeof(SearchMethod), true)] + public class SearchMethodEditor : LLMCallerEditor + { + public override void AddChatSettings(SerializedObject llmScriptSO) {} + } - if (llmScriptSO.FindProperty("advancedOptions").boolValue) + [CustomEditor(typeof(LLMCharacter), true)] + public class LLMCharacterEditor : LLMCallerEditor + { + public override void AddModelSettings(SerializedObject llmScriptSO) + { + if (!llmScriptSO.FindProperty("advancedOptions").boolValue) + { + base.AddModelSettings(llmScriptSO); + } + else { + EditorGUILayout.LabelField("Model Settings", EditorStyles.boldLabel); + ShowPropertiesOfClass("", llmScriptSO, new List<Type> { typeof(ModelAttribute) }, false); + EditorGUILayout.BeginHorizontal(); GUILayout.Label("Grammar", GUILayout.Width(EditorGUIUtility.labelWidth)); if (GUILayout.Button("Load grammar", GUILayout.Width(buttonWidth))) @@ -29,7 +53,7 @@ public void AddModelSettings(SerializedObject llmScriptSO, LLMCharacter llmCharacterScript) string path = EditorUtility.OpenFilePanelWithFilters("Select a gbnf grammar file", "", new string[] { "Grammar Files", "gbnf" }); if (!string.IsNullOrEmpty(path)) { - llmCharacterScript.SetGrammar(path); + ((LLMCharacter)target).SetGrammar(path); } }; } @@ -38,22 +62,6 @@ public void AddModelSettings(SerializedObject llmScriptSO, LLMCharacter llmCharacterScript) ShowPropertiesOfClass("", llmScriptSO, new List<Type> { typeof(ModelAdvancedAttribute) }, false); } } - - public override void OnInspectorGUI() - { - LLMCharacter llmScript = (LLMCharacter)target; - SerializedObject llmScriptSO = new SerializedObject(llmScript); - - OnInspectorGUIStart(llmScriptSO); - AddOptionsToggles(llmScriptSO); - - AddSetupSettings(llmScriptSO); - AddChatSettings(llmScriptSO); - Space(); - AddModelSettings(llmScriptSO, llmScript); - - OnInspectorGUIEnd(llmScriptSO); - } } [CustomEditor(typeof(LLMClient))] diff --git a/Editor/LLMCharacterEditor.cs.meta b/Editor/LLMCallerEditor.cs.meta similarity index 100% rename from Editor/LLMCharacterEditor.cs.meta rename to Editor/LLMCallerEditor.cs.meta diff --git a/Editor/LLMEditor.cs b/Editor/LLMEditor.cs index 38547dcd..6ad8a8a0 100644 --- a/Editor/LLMEditor.cs +++ b/Editor/LLMEditor.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Threading.Tasks; using UnityEditor; @@ -18,7
+19,8 @@ public class LLMEditor : PropertyEditor static float includeInBuildColumnWidth = 30f; static float actionColumnWidth = 20f; static int elementPadding = 10; - static GUIContent trashIcon; + static GUIContent trashGUIContent; + static Dictionary<string, Texture2D> icons = new Dictionary<string, Texture2D>(); static List<string> modelOptions; static List<string> modelLicenses; static List<string> modelURLs; @@ -29,11 +31,6 @@ public class LLMEditor : PropertyEditor bool customURLFocus = false; bool expandedView = false; - protected override Type[] GetPropertyTypes() - { - return new Type[] { typeof(LLM) }; - } - public void AddSecuritySettings(SerializedObject llmScriptSO, LLM llmScript) { void AddSSLLoad(string type, Callback<string> setterCallback) @@ -55,7 +52,7 @@ void AddSSLInfo(string propertyName, string type, Callback<string> setterCallback) { EditorGUILayout.BeginHorizontal(); EditorGUILayout.LabelField("SSL " + type + " path", path); - if (GUILayout.Button(trashIcon, GUILayout.Height(actionColumnWidth), GUILayout.Width(actionColumnWidth))) setterCallback(""); + if (GUILayout.Button(trashGUIContent, GUILayout.Height(actionColumnWidth), GUILayout.Width(actionColumnWidth))) setterCallback(""); EditorGUILayout.EndHorizontal(); } } @@ -109,7 +106,7 @@ public void AddModelLoaders(SerializedObject llmScriptSO, LLM llmScript) if (downloadOnStart != LLMManager.downloadOnStart) LLMManager.SetDownloadOnStart(downloadOnStart); } - public void AddModelSettings(SerializedObject llmScriptSO) + public override void AddModelSettings(SerializedObject llmScriptSO) { List<Type> attributeClasses = new List<Type> { typeof(ModelAttribute) }; if (llmScriptSO.FindProperty("advancedOptions").boolValue) @@ -298,11 +295,20 @@ async Task AddLoadButtons() else await createButtons(); } + void LoadIcons() + { + if (icons.Count > 0) return; + icons["πŸ—‘οΈ"] = Resources.Load<Texture2D>("llmunity_trash_icon"); + icons["πŸ’¬"] = Resources.Load<Texture2D>("llmunity_speechballoon_icon"); + icons["πŸ”"] = Resources.Load<Texture2D>("llmunity_magnifier_icon"); + } + void OnEnable() { LLM llmScript = (LLM)target; + LoadIcons(); ResetModelOptions(); - trashIcon = new GUIContent(Resources.Load<Texture2D>("llmunity_trash_icon"), "Delete Model"); + trashGUIContent = new GUIContent(icons["πŸ—‘οΈ"], "Delete Model"); Texture2D loraLineTexture = new Texture2D(1, 1); loraLineTexture.SetPixel(0, 0, Color.black); loraLineTexture.Apply(); @@ -387,7 +393,7 @@ void OnEnable() UpdateModels(); } - if (GUI.Button(actionRect, trashIcon)) + if (GUI.Button(actionRect, trashGUIContent)) { if (isSelected) { @@ -426,7 +432,20 @@ void OnEnable() private void DrawCopyableLabel(Rect rect, string label, string text = "") { if (text == "") text = label; - EditorGUI.LabelField(rect, label); + string labelToShow = label; + foreach (var icon in icons) + { + if (StringInfo.GetNextTextElement(label) == icon.Key) + { + float iconSize = rect.height * 3 / 4; + GUI.DrawTexture(new Rect(rect.x, rect.y + (rect.height - iconSize) / 2, iconSize, iconSize), icon.Value); + rect.x += iconSize; + rect.width -= iconSize; + labelToShow = label.Substring(icon.Key.Length); + break; + } + } + EditorGUI.LabelField(rect, labelToShow); if (Event.current.type == EventType.ContextClick && rect.Contains(Event.current.mousePosition)) { GenericMenu menu = new GenericMenu(); @@ -443,6 +462,22 @@ private void CopyToClipboard(string text) te.Copy(); } + public void AddExtrasToggle() + { + if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib); + } + + public override void AddOptionsToggles(SerializedObject llmScriptSO) + { +
AddDebugModeToggle(); + + EditorGUILayout.BeginHorizontal(); + AddAdvancedOptionsToggle(llmScriptSO); + AddExtrasToggle(); + EditorGUILayout.EndHorizontal(); + Space(); + } + public override void OnInspectorGUI() { if (elementFocus != "") diff --git a/Editor/PropertyEditor.cs b/Editor/PropertyEditor.cs index d3510ac8..d6e078ed 100644 --- a/Editor/PropertyEditor.cs +++ b/Editor/PropertyEditor.cs @@ -3,6 +3,7 @@ using UnityEngine; using System.Reflection; using System.Collections.Generic; +using NUnit.Framework.Internal; namespace LLMUnity { @@ -10,20 +11,20 @@ public class PropertyEditor : Editor { public static int buttonWidth = 150; - public void AddScript(SerializedObject llmScriptSO) + public virtual void AddScript(SerializedObject llmScriptSO) { var scriptProp = llmScriptSO.FindProperty("m_Script"); EditorGUILayout.PropertyField(scriptProp); } - public bool ToggleButton(string text, bool activated) + public virtual bool ToggleButton(string text, bool activated) { GUIStyle style = new GUIStyle("Button"); if (activated) style.normal = new GUIStyleState() { background = Texture2D.grayTexture }; return GUILayout.Button(text, style, GUILayout.Width(buttonWidth)); } - public void AddSetupSettings(SerializedObject llmScriptSO) + public virtual void AddSetupSettings(SerializedObject llmScriptSO) { List attributeClasses = new List(){typeof(LocalRemoteAttribute)}; attributeClasses.Add(llmScriptSO.FindProperty("remote").boolValue ? typeof(RemoteAttribute) : typeof(LocalAttribute)); @@ -35,37 +36,62 @@ public void AddSetupSettings(SerializedObject llmScriptSO) ShowPropertiesOfClass("Setup Settings", llmScriptSO, attributeClasses, true); } - public void AddChatSettings(SerializedObject llmScriptSO) + public virtual void AddModelSettings(SerializedObject llmScriptSO) + { + List attributeClasses = new List(){typeof(ModelAttribute)}; + if (llmScriptSO.FindProperty("advancedOptions").boolValue) + { + attributeClasses.Add(typeof(ModelAdvancedAttribute)); + } + ShowPropertiesOfClass("Model Settings", llmScriptSO, attributeClasses, true); + } + + public virtual void AddChatSettings(SerializedObject llmScriptSO) { List attributeClasses = new List(){typeof(ChatAttribute)}; if (llmScriptSO.FindProperty("advancedOptions").boolValue) { attributeClasses.Add(typeof(ChatAdvancedAttribute)); } - ShowPropertiesOfClass("Chat Settings", llmScriptSO, attributeClasses, false); + ShowPropertiesOfClass("Chat Settings", llmScriptSO, attributeClasses, true); } - public void AddOptionsToggles(SerializedObject llmScriptSO) + public void AddDebugModeToggle() { LLMUnitySetup.SetDebugMode((LLMUnitySetup.DebugModeType)EditorGUILayout.EnumPopup("Log Level", LLMUnitySetup.DebugMode)); + } - EditorGUILayout.BeginHorizontal(); + public void AddAdvancedOptionsToggle(SerializedObject llmScriptSO) + { SerializedProperty advancedOptionsProp = llmScriptSO.FindProperty("advancedOptions"); string toggleText = (advancedOptionsProp.boolValue ? 
"Hide" : "Show") + " Advanced Options"; if (ToggleButton(toggleText, advancedOptionsProp.boolValue)) advancedOptionsProp.boolValue = !advancedOptionsProp.boolValue; - if (ToggleButton("Use extras", LLMUnitySetup.FullLlamaLib)) LLMUnitySetup.SetFullLlamaLib(!LLMUnitySetup.FullLlamaLib); - EditorGUILayout.EndHorizontal(); + } + + public virtual void AddOptionsToggles(SerializedObject llmScriptSO) + { + AddDebugModeToggle(); + AddAdvancedOptionsToggle(llmScriptSO); Space(); } - public void Space() + public virtual void Space() { EditorGUILayout.Space((int)EditorGUIUtility.singleLineHeight / 2); } protected virtual Type[] GetPropertyTypes() { - return new Type[] {}; + List types = new List(); + Type currentType = target.GetType(); + while (currentType != null) + { + types.Add(currentType); + currentType = currentType.BaseType; + if (currentType == typeof(MonoBehaviour)) break; + } + types.Reverse(); + return types.ToArray(); } public List GetPropertiesOfClass(SerializedObject so, List attributeClasses) @@ -92,7 +118,7 @@ public List GetPropertiesOfClass(SerializedObject so, List attributeClasses, bool addSpace = true, List excludeAttributeClasses = null) + public bool ShowPropertiesOfClass(string title, SerializedObject so, List attributeClasses, bool addSpace = true, List excludeAttributeClasses = null) { // display a property if it belongs to a certain class and/or has a specific attribute class List properties = GetPropertiesOfClass(so, attributeClasses); @@ -109,7 +135,7 @@ public void ShowPropertiesOfClass(string title, SerializedObject so, List } foreach (SerializedProperty prop in removeProperties) properties.Remove(prop); } - if (properties.Count == 0) return; + if (properties.Count == 0) return false; if (title != "") EditorGUILayout.LabelField(title, EditorStyles.boldLabel); foreach (SerializedProperty prop in properties) { @@ -129,6 +155,7 @@ public void ShowPropertiesOfClass(string title, SerializedObject so, List } } if (addSpace) Space(); + return true; } public bool PropertyInClass(SerializedProperty prop, Type targetClass, Type attributeClass = null) @@ -162,7 +189,7 @@ public Attribute GetPropertyAttribute(SerializedProperty prop, Type attributeCla return null; } - public void OnInspectorGUIStart(SerializedObject scriptSO) + public virtual void OnInspectorGUIStart(SerializedObject scriptSO) { scriptSO.Update(); GUI.enabled = false; @@ -171,7 +198,7 @@ public void OnInspectorGUIStart(SerializedObject scriptSO) EditorGUI.BeginChangeCheck(); } - public void OnInspectorGUIEnd(SerializedObject scriptSO) + public virtual void OnInspectorGUIEnd(SerializedObject scriptSO) { if (EditorGUI.EndChangeCheck()) Repaint(); diff --git a/Resources/llmunity_magnifier_icon.png b/Resources/llmunity_magnifier_icon.png new file mode 100644 index 00000000..422bae4d Binary files /dev/null and b/Resources/llmunity_magnifier_icon.png differ diff --git a/Resources/llmunity_magnifier_icon.png.meta b/Resources/llmunity_magnifier_icon.png.meta new file mode 100644 index 00000000..1bddf249 --- /dev/null +++ b/Resources/llmunity_magnifier_icon.png.meta @@ -0,0 +1,140 @@ +fileFormatVersion: 2 +guid: fe6761bddab789db4acea6739a13ad88 +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 12 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + convertToNormalMap: 0 + externalNormalMap: 0 + 
heightScale: 0.25 + normalMapFilter: 0 + flipGreenChannel: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + vTOnly: 0 + ignoreMipmapLimit: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: 1 + aniso: 1 + mipBias: 0 + wrapU: 0 + wrapV: 0 + wrapW: 0 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + flipbookRows: 1 + flipbookColumns: 1 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + ignorePngGamma: 0 + applyGammaDecoding: 0 + swizzle: 50462976 + cookieLightType: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Standalone + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Android + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Server + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + nameFileIdTable: {} + mipmapLimitGroupName: + pSDRemoveMatte: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/Resources/llmunity_speechballoon_icon.png b/Resources/llmunity_speechballoon_icon.png new file mode 100644 index 00000000..2e975542 Binary files /dev/null and b/Resources/llmunity_speechballoon_icon.png differ diff --git a/Resources/llmunity_speechballoon_icon.png.meta b/Resources/llmunity_speechballoon_icon.png.meta new file mode 100644 index 00000000..3df410c3 --- /dev/null +++ b/Resources/llmunity_speechballoon_icon.png.meta @@ -0,0 +1,140 @@ +fileFormatVersion: 2 +guid: 35f8310efe5c82e9285174efa542947a +TextureImporter: + internalIDToNameTable: [] + externalObjects: {} + serializedVersion: 12 + mipmaps: + mipMapMode: 0 + enableMipMap: 1 + sRGBTexture: 1 + linearTexture: 0 + fadeOut: 0 + borderMipMap: 0 + mipMapsPreserveCoverage: 0 + alphaTestReferenceValue: 0.5 + mipMapFadeDistanceStart: 1 + mipMapFadeDistanceEnd: 3 + bumpmap: + 
convertToNormalMap: 0 + externalNormalMap: 0 + heightScale: 0.25 + normalMapFilter: 0 + flipGreenChannel: 0 + isReadable: 0 + streamingMipmaps: 0 + streamingMipmapsPriority: 0 + vTOnly: 0 + ignoreMipmapLimit: 0 + grayScaleToAlpha: 0 + generateCubemap: 6 + cubemapConvolution: 0 + seamlessCubemap: 0 + textureFormat: 1 + maxTextureSize: 2048 + textureSettings: + serializedVersion: 2 + filterMode: 1 + aniso: 1 + mipBias: 0 + wrapU: 0 + wrapV: 0 + wrapW: 0 + nPOTScale: 1 + lightmap: 0 + compressionQuality: 50 + spriteMode: 0 + spriteExtrude: 1 + spriteMeshType: 1 + alignment: 0 + spritePivot: {x: 0.5, y: 0.5} + spritePixelsToUnits: 100 + spriteBorder: {x: 0, y: 0, z: 0, w: 0} + spriteGenerateFallbackPhysicsShape: 1 + alphaUsage: 1 + alphaIsTransparency: 0 + spriteTessellationDetail: -1 + textureType: 0 + textureShape: 1 + singleChannelComponent: 0 + flipbookRows: 1 + flipbookColumns: 1 + maxTextureSizeSet: 0 + compressionQualitySet: 0 + textureFormatSet: 0 + ignorePngGamma: 0 + applyGammaDecoding: 0 + swizzle: 50462976 + cookieLightType: 0 + platformSettings: + - serializedVersion: 3 + buildTarget: DefaultTexturePlatform + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Standalone + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Android + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + - serializedVersion: 3 + buildTarget: Server + maxTextureSize: 2048 + resizeAlgorithm: 0 + textureFormat: -1 + textureCompression: 1 + compressionQuality: 50 + crunchedCompression: 0 + allowsAlphaSplitting: 0 + overridden: 0 + ignorePlatformSupport: 0 + androidETC2FallbackOverride: 0 + forceMaximumCompressionQuality_BC6H_BC7: 0 + spriteSheet: + serializedVersion: 2 + sprites: [] + outline: [] + physicsShape: [] + bones: [] + spriteID: + internalID: 0 + vertices: [] + indices: + edges: [] + weights: [] + secondaryTextures: [] + nameFileIdTable: {} + mipmapLimitGroupName: + pSDRemoveMatte: 0 + userData: + assetBundleName: + assetBundleVariant: diff --git a/Resources/llmunity_trash_icon.png b/Resources/llmunity_trash_icon.png index 7457cc94..6b6c79bb 100644 Binary files a/Resources/llmunity_trash_icon.png and b/Resources/llmunity_trash_icon.png differ diff --git a/Runtime/DBSearch.cs b/Runtime/DBSearch.cs new file mode 100644 index 00000000..758d293b --- /dev/null +++ b/Runtime/DBSearch.cs @@ -0,0 +1,96 @@ +using System; +using System.Collections.Generic; +using Cloud.Unum.USearch; +using System.IO.Compression; +using UnityEngine; + +namespace LLMUnity +{ + [DefaultExecutionOrder(-2)] + public class DBSearch : SearchMethod + { + public USearchIndex index; + [ModelAdvanced] public ScalarKind quantization = ScalarKind.Float16; + [ModelAdvanced] public MetricKind metricKind = MetricKind.Cos; + [ModelAdvanced] public ulong connectivity = 32; + 
[ModelAdvanced] public ulong expansionAdd = 40; + [ModelAdvanced] public ulong expansionSearch = 16; + private Dictionary<int, (float[], List<int>)> incrementalSearchCache = new Dictionary<int, (float[], List<int>)>(); + + public override void Awake() + { + if (!enabled) return; + base.Awake(); + InitIndex(); + } + + public void InitIndex() + { + index = new USearchIndex(metricKind, quantization, (ulong)llm.embeddingLength, connectivity, expansionAdd, expansionSearch, false); + } + + protected override void AddInternal(int key, float[] embedding) + { + index.Add((ulong)key, embedding); + } + + protected override void RemoveInternal(int key) + { + index.Remove((ulong)key); + } + + protected override int[] SearchInternal(float[] embedding, int k, out float[] distances) + { + index.Search(embedding, k, out ulong[] keys, out distances); + int[] intKeys = new int[keys.Length]; + for (int i = 0; i < keys.Length; i++) intKeys[i] = (int)keys[i]; + return intKeys; + } + + public override int IncrementalSearch(float[] embedding) + { + int key = nextIncrementalSearchKey++; + incrementalSearchCache[key] = (embedding, new List<int>()); + return key; + } + + public override (int[], bool) IncrementalFetchKeys(int fetchKey, int k) + { + if (!incrementalSearchCache.ContainsKey(fetchKey)) throw new Exception($"There is no IncrementalSearch cached with this key: {fetchKey}"); + + float[] embedding; + List<int> seenKeys; + (embedding, seenKeys) = incrementalSearchCache[fetchKey]; + int matches = index.Search(embedding, k, out ulong[] keys, out float[] distances, (int key, IntPtr state) => {return seenKeys.Contains(key) ? 0 : 1;}); + int[] intKeys = new int[keys.Length]; + for (int i = 0; i < keys.Length; i++) intKeys[i] = (int)keys[i]; + incrementalSearchCache[fetchKey].Item2.AddRange(intKeys); + + bool completed = matches < k || seenKeys.Count == Count(); + if (completed) IncrementalSearchComplete(fetchKey); + return (intKeys, completed); + } + + public override void IncrementalSearchComplete(int fetchKey) + { + incrementalSearchCache.Remove(fetchKey); + } + + protected override void SaveInternal(ZipArchive archive) + { + index.Save(archive); + } + + protected override void LoadInternal(ZipArchive archive) + { + index.Load(archive); + } + + protected override void ClearInternal() + { + index.Dispose(); + InitIndex(); + incrementalSearchCache.Clear(); + } + } +} diff --git a/Runtime/DBSearch.cs.meta b/Runtime/DBSearch.cs.meta new file mode 100644 index 00000000..aee765d6 --- /dev/null +++ b/Runtime/DBSearch.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: f8c6492f8c97ab9a09b0eb0f93a158da +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index 55f2f379..6585e9b2 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -31,7 +31,7 @@ public class LLM : MonoBehaviour [LLM] public int numGPULayers = 0; /// select to log the output of the LLM in the Unity Editor. [LLM] public bool debug = false; - /// number of prompts that can happen in parallel (-1 = number of LLMCharacter objects) + /// number of prompts that can happen in parallel (-1 = number of LLMCaller objects) [LLMAdvanced] public int parallelPrompts = -1; /// select to not destroy the LLM GameObject when loading a new Scene. [LLMAdvanced] public bool dontDestroyOnLoad = true; @@ -40,7 +40,7 @@ public class LLM : MonoBehaviour [ModelAdvanced] public int contextSize = 0; /// Batch size for prompt processing.
[ModelAdvanced] public int batchSize = 512; - /// a base prompt to use as a base for all LLMCharacter objects + /// a base prompt to use as a base for all LLMCaller objects [TextArea(5, 10), ChatAdvanced] public string basePrompt = ""; /// Boolean set to true if the server has started and is ready to receive requests, false otherwise. public bool started { get; protected set; } = false; @@ -78,7 +78,7 @@ public class LLM : MonoBehaviour /// \cond HIDE IntPtr LLMObject = IntPtr.Zero; - List clients = new List(); + List clients = new List(); LLMLib llmlib; StreamWrapper logStreamWrapper = null; Thread llmThread = null; @@ -89,6 +89,8 @@ public class LLM : MonoBehaviour public LoraManager loraManager = new LoraManager(); string loraPre = ""; string loraWeightsPre = ""; + public bool embeddingsOnly = false; + public int embeddingLength = 0; /// \endcond @@ -211,6 +213,7 @@ public void SetModel(string path) ModelEntry modelEntry = LLMManager.Get(model); if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model)); SetTemplate(modelEntry.chatTemplate); + SetEmbeddings(modelEntry.embeddingLength, modelEntry.embeddingOnly); if (contextSize == 0 && modelEntry.contextLength > 32768) { LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM"); @@ -314,6 +317,21 @@ public void SetTemplate(string templateName, bool setDirty = true) #endif } + /// + /// Set LLM Embedding parameters + /// + /// number of embedding dimensions + /// if true, the LLM will be used only for embeddings + public void SetEmbeddings(int embeddingLength, bool embeddingsOnly) + { + if (embeddingsOnly) LLMUnitySetup.LogWarning("This model can only be used for embeddings"); + this.embeddingsOnly = embeddingsOnly; + this.embeddingLength = embeddingLength; +#if UNITY_EDITOR + if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); +#endif + } + /// \cond HIDE string ReadFileContents(string path) @@ -397,15 +415,16 @@ protected virtual string GetLlamaccpArguments() int slots = GetNumClients(); string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}"; + if (embeddingsOnly) arguments += " --embedding"; + if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}"; + arguments += loraArgument; + arguments += $" -ngl {numGPULayers}"; + if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn"; if (remote) { arguments += $" --port {port} --host 0.0.0.0"; if (!String.IsNullOrEmpty(APIKey)) arguments += $" --api-key {APIKey}"; } - if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}"; - arguments += loraArgument; - arguments += $" -ngl {numGPULayers}"; - if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn"; // the following is the equivalent for running from command line string serverCommand; @@ -414,7 +433,7 @@ protected virtual string GetLlamaccpArguments() serverCommand += " " + arguments; serverCommand += $" --template {chatTemplate}"; if (remote && SSLCert != "" && SSLKey != "") serverCommand += $" --ssl-cert-file {SSLCertPath} --ssl-key-file {SSLKeyPath}"; - LLMUnitySetup.Log($"Server command: {serverCommand}"); + LLMUnitySetup.Log($"Server deployment command: {serverCommand}"); return arguments; } @@ -517,15 +536,15 @@ private void StartService() } /// - /// Registers a local LLMCharacter object. - /// This allows to bind the LLMCharacter "client" to a specific slot of the LLM. 
+ /// Registers a local LLMCaller object. + /// This allows to bind the LLMCaller "client" to a specific slot of the LLM. /// - /// + /// /// - public int Register(LLMCharacter llmCharacter) + public int Register(LLMCaller llmCaller) { - clients.Add(llmCharacter); - int index = clients.IndexOf(llmCharacter); + clients.Add(llmCaller); + int index = clients.IndexOf(llmCaller); if (parallelPrompts != -1) return index % parallelPrompts; return index; } diff --git a/Runtime/LLMCaller.cs b/Runtime/LLMCaller.cs new file mode 100644 index 00000000..e80d6dea --- /dev/null +++ b/Runtime/LLMCaller.cs @@ -0,0 +1,394 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using UnityEngine; +using UnityEngine.Networking; + +namespace LLMUnity +{ + [DefaultExecutionOrder(-2)] + public class LLMCaller : MonoBehaviour + { + /// toggle to show/hide advanced options in the GameObject + [HideInInspector] public bool advancedOptions = false; + /// toggle to use remote LLM server or local LLM + [LocalRemote] public bool remote = false; + /// the LLM object to use + [Local] public LLM llm; + /// option to receive the reply from the model as it is produced (recommended!). + /// If it is not selected, the full reply from the model is received in one go + [Chat] public bool stream = true; + /// allows to use a server with API key + [Remote] public string APIKey; + /// specify which slot of the server to use for computation (affects caching) + [ModelAdvanced] public int slot = -1; + + /// host to use for the LLM server + [Remote] public string host = "localhost"; + /// port to use for the LLM server + [Remote] public int port = 13333; + /// number of retries to use for the LLM server requests (-1 = infinite) + [Remote] public int numRetries = 10; + + private List<(string, string)> requestHeaders; + private List WIPRequests = new List(); + + /// + /// The Unity Awake function that initializes the state before the application starts. 
+ /// The following actions are executed: + /// - the corresponding LLM server is defined (if ran locally) + /// - the grammar is set based on the grammar file + /// - the prompt and chat history are initialised + /// - the chat template is constructed + /// - the number of tokens to keep are based on the system prompt (if setNKeepToPrompt=true) + /// + public virtual void Awake() + { + // Start the LLM server in a cross-platform way + if (!enabled) return; + + requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; + if (!remote) + { + AssignLLM(); + if (llm == null) + { + LLMUnitySetup.LogError($"No LLM assigned or detected for LLMCharacter {name}!"); + return; + } + int slotFromServer = llm.Register(this); + if (slot == -1) slot = slotFromServer; + } + else + { + if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey)); + } + } + + protected virtual void OnValidate() + { + AssignLLM(); + if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set"); + } + + protected virtual void Reset() + { + AssignLLM(); + } + + protected virtual void AssignLLM() + { + if (remote || llm != null) return; + + LLM[] existingLLMs = FindObjectsOfType(); + if (existingLLMs.Length == 0) return; + + SortBySceneAndHierarchy(existingLLMs); + llm = existingLLMs[0]; + string msg = $"Assigning LLM {llm.name} to LLMCharacter {name}"; + if (llm.gameObject.scene != gameObject.scene) msg += $" from scene {llm.gameObject.scene}"; + LLMUnitySetup.Log(msg); + } + + protected virtual void SortBySceneAndHierarchy(LLM[] array) + { + for (int i = 0; i < array.Length - 1; i++) + { + bool swapped = false; + for (int j = 0; j < array.Length - i - 1; j++) + { + bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene; + bool swap = ( + (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) || + (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex()) + ); + if (swap) + { + LLM temp = array[j]; + array[j] = array[j + 1]; + array[j + 1] = temp; + swapped = true; + } + } + if (!swapped) break; + } + } + + protected string ChatContent(ChatResult result) + { + // get content from a chat result received from the endpoint + return result.content.Trim(); + } + + protected string MultiChatContent(MultiChatResult result) + { + // get content from a chat result received from the endpoint + string response = ""; + foreach (ChatResult resultPart in result.data) + { + response += resultPart.content; + } + return response.Trim(); + } + + protected async Task CompletionRequest(string json, Callback callback = null) + { + string result = ""; + if (stream) + { + result = await PostRequest(json, "completion", MultiChatContent, callback); + } + else + { + result = await PostRequest(json, "completion", ChatContent, callback); + } + return result; + } + + protected string TemplateContent(TemplateResult result) + { + // get content from a char result received from the endpoint in open AI format + return result.template; + } + + protected List TokenizeContent(TokenizeResult result) + { + // get the tokens from a tokenize result received from the endpoint + return result.tokens; + } + + protected string DetokenizeContent(TokenizeRequest result) + { + // get content from a chat result received from the endpoint + return result.content; + } + + protected List 
EmbeddingsContent(EmbeddingsResult result) + { + // get content from a chat result received from the endpoint + return result.embedding; + } + + protected string SlotContent(SlotResult result) + { + // get the tokens from a tokenize result received from the endpoint + return result.filename; + } + + protected virtual Ret ConvertContent(string response, ContentCallback getContent = null) + { + // template function to convert the json received and get the content + if (response == null) return default; + response = response.Trim(); + if (response.StartsWith("data: ")) + { + string responseArray = ""; + foreach (string responsePart in response.Replace("\n\n", "").Split("data: ")) + { + if (responsePart == "") continue; + if (responseArray != "") responseArray += ",\n"; + responseArray += responsePart; + } + response = $"{{\"data\": [{responseArray}]}}"; + } + return getContent(JsonUtility.FromJson(response)); + } + + protected virtual void CancelRequestsLocal() + { + if (slot >= 0) llm.CancelRequest(slot); + } + + protected virtual void CancelRequestsRemote() + { + foreach (UnityWebRequest request in WIPRequests) + { + request.Abort(); + } + WIPRequests.Clear(); + } + + /// + /// Cancel the ongoing requests e.g. Chat, Complete. + /// + // + public void CancelRequests() + { + if (remote) CancelRequestsRemote(); + else CancelRequestsLocal(); + } + + protected virtual async Task PostRequestLocal(string json, string endpoint, ContentCallback getContent, Callback callback = null) + { + // send a post request to the server and call the relevant callbacks to convert the received content and handle it + // this function has streaming functionality i.e. handles the answer while it is being received + string callResult = null; + bool callbackCalled = false; + while (!llm.failed && !llm.started) await Task.Yield(); + switch (endpoint) + { + case "tokenize": + callResult = await llm.Tokenize(json); + break; + case "detokenize": + callResult = await llm.Detokenize(json); + break; + case "embeddings": + callResult = await llm.Embeddings(json); + break; + case "slots": + callResult = await llm.Slot(json); + break; + case "completion": + if (llm.embeddingsOnly) LLMUnitySetup.LogError("The LLM can't be used for completion, only for embeddings"); + else + { + Callback callbackString = null; + if (stream && callback != null) + { + if (typeof(Ret) == typeof(string)) + { + callbackString = (strArg) => + { + callback(ConvertContent(strArg, getContent)); + }; + } + else + { + LLMUnitySetup.LogError($"wrong callback type, should be string"); + } + callbackCalled = true; + } + callResult = await llm.Completion(json, callbackString); + } + break; + default: + LLMUnitySetup.LogError($"Unknown endpoint {endpoint}"); + break; + } + + Ret result = ConvertContent(callResult, getContent); + if (!callbackCalled) callback?.Invoke(result); + return result; + } + + protected virtual async Task PostRequestRemote(string json, string endpoint, ContentCallback getContent, Callback callback = null) + { + // send a post request to the server and call the relevant callbacks to convert the received content and handle it + // this function has streaming functionality i.e. 
handles the answer while it is being received + if (endpoint == "slots") + { + LLMUnitySetup.LogError("Saving and loading is not currently supported in remote setting"); + return default; + } + + Ret result = default; + byte[] jsonToSend = new System.Text.UTF8Encoding().GetBytes(json); + UnityWebRequest request = null; + string error = null; + int tryNr = numRetries; + + while (tryNr != 0) + { + using (request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend)) + { + WIPRequests.Add(request); + + request.method = "POST"; + if (requestHeaders != null) + { + for (int i = 0; i < requestHeaders.Count; i++) + request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2); + } + + // Start the request asynchronously + var asyncOperation = request.SendWebRequest(); + float lastProgress = 0f; + // Continue updating progress until the request is completed + while (!asyncOperation.isDone) + { + float currentProgress = request.downloadProgress; + // Check if progress has changed + if (currentProgress != lastProgress && callback != null) + { + callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent)); + lastProgress = currentProgress; + } + // Wait for the next frame + await Task.Yield(); + } + WIPRequests.Remove(request); + if (request.result == UnityWebRequest.Result.Success) + { + result = ConvertContent(request.downloadHandler.text, getContent); + error = null; + break; + } + else + { + result = default; + error = request.error; + if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break; + } + } + tryNr--; + if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr)); + } + + if (error != null) LLMUnitySetup.LogError(error); + callback?.Invoke(result); + return result; + } + + protected virtual async Task PostRequest(string json, string endpoint, ContentCallback getContent, Callback callback = null) + { + if (remote) return await PostRequestRemote(json, endpoint, getContent, callback); + return await PostRequestLocal(json, endpoint, getContent, callback); + } + + /// + /// Tokenises the provided query. + /// + /// query to tokenise + /// callback function called with the result tokens + /// list of the tokens + public async Task> Tokenize(string query, Callback> callback = null) + { + // handle the tokenization of a message by the user + TokenizeRequest tokenizeRequest = new TokenizeRequest(); + tokenizeRequest.content = query; + string json = JsonUtility.ToJson(tokenizeRequest); + return await PostRequest>(json, "tokenize", TokenizeContent, callback); + } + + /// + /// Detokenises the provided tokens to a string. + /// + /// tokens to detokenise + /// callback function called with the result string + /// the detokenised string + public async Task Detokenize(List tokens, Callback callback = null) + { + // handle the detokenization of a message by the user + TokenizeResult tokenizeRequest = new TokenizeResult(); + tokenizeRequest.tokens = tokens; + string json = JsonUtility.ToJson(tokenizeRequest); + return await PostRequest(json, "detokenize", DetokenizeContent, callback); + } + + /// + /// Computes the embeddings of the provided input. 
+ /// + /// input to compute the embeddings for + /// callback function called with the result string + /// the computed embeddings + public async Task> Embeddings(string query, Callback> callback = null) + { + // handle the tokenization of a message by the user + TokenizeRequest tokenizeRequest = new TokenizeRequest(); + tokenizeRequest.content = query; + string json = JsonUtility.ToJson(tokenizeRequest); + return await PostRequest>(json, "embeddings", EmbeddingsContent, callback); + } + } +} diff --git a/Runtime/LLMCaller.cs.meta b/Runtime/LLMCaller.cs.meta new file mode 100644 index 00000000..03714d58 --- /dev/null +++ b/Runtime/LLMCaller.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 01594396c2699f0ecb48ead86a6b1bc5 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 8b7ee452..f57a25f9 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -7,7 +7,6 @@ using System.Threading.Tasks; using UnityEditor; using UnityEngine; -using UnityEngine.Networking; namespace LLMUnity { @@ -16,22 +15,8 @@ namespace LLMUnity /// /// Class implementing the LLM characters. /// - public class LLMCharacter : MonoBehaviour + public class LLMCharacter : LLMCaller { - /// toggle to show/hide advanced options in the GameObject - [HideInInspector] public bool advancedOptions = false; - /// toggle to use remote LLM server or local LLM - [LocalRemote] public bool remote = false; - /// the LLM object to use - [Local] public LLM llm; - /// host to use for the LLM server - [Remote] public string host = "localhost"; - /// port to use for the LLM server - [Remote] public int port = 13333; - /// number of retries to use for the LLM server requests (-1 = infinite) - [Remote] public int numRetries = 10; - /// allows to use a server with API key - [Remote] public string APIKey; /// file to save the chat history. /// The file is saved only for Chat calls with addToHistory set to true. /// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). @@ -40,15 +25,10 @@ public class LLMCharacter : MonoBehaviour [LLM] public bool saveCache = false; /// select to log the constructed prompt the Unity Editor. [LLM] public bool debugPrompt = false; - /// option to receive the reply from the model as it is produced (recommended!). - /// If it is not selected, the full reply from the model is received in one go - [Model] public bool stream = true; /// grammar file used for the LLM in .cbnf format (relative to the Assets/StreamingAssets folder) [ModelAdvanced] public string grammar = null; /// option to cache the prompt as it is being created by the chat to avoid reprocessing the entire prompt every time (default: true) [ModelAdvanced] public bool cachePrompt = true; - /// specify which slot of the server to use for computation (affects caching) - [ModelAdvanced] public int slot = -1; /// seed for reproducibility. For random results every time set to -1. [ModelAdvanced] public int seed = 0; /// number of tokens to predict (-1 = infinity, -2 = until context filled). 
@@ -127,8 +107,6 @@ public class LLMCharacter : MonoBehaviour private string chatTemplate; private ChatTemplate template = null; public string grammarString; - private List<(string, string)> requestHeaders; - private List WIPRequests = new List(); /// \endcond /// @@ -140,81 +118,14 @@ public class LLMCharacter : MonoBehaviour /// - the chat template is constructed /// - the number of tokens to keep are based on the system prompt (if setNKeepToPrompt=true) /// - public void Awake() + public override void Awake() { - // Start the LLM server in a cross-platform way if (!enabled) return; - - requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; - if (!remote) - { - AssignLLM(); - if (llm == null) - { - LLMUnitySetup.LogError($"No LLM assigned or detected for LLMCharacter {name}!"); - return; - } - int slotFromServer = llm.Register(this); - if (slot == -1) slot = slotFromServer; - } - else - { - if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey)); - } - + base.Awake(); InitGrammar(); InitHistory(); } - void OnValidate() - { - AssignLLM(); - if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set"); - } - - void Reset() - { - AssignLLM(); - } - - void AssignLLM() - { - if (remote || llm != null) return; - - LLM[] existingLLMs = FindObjectsOfType(); - if (existingLLMs.Length == 0) return; - - SortBySceneAndHierarchy(existingLLMs); - llm = existingLLMs[0]; - string msg = $"Assigning LLM {llm.name} to LLMCharacter {name}"; - if (llm.gameObject.scene != gameObject.scene) msg += $" from scene {llm.gameObject.scene}"; - LLMUnitySetup.Log(msg); - } - - void SortBySceneAndHierarchy(LLM[] array) - { - for (int i = 0; i < array.Length - 1; i++) - { - bool swapped = false; - for (int j = 0; j < array.Length - i - 1; j++) - { - bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene; - bool swap = ( - (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) || - (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex()) - ); - if (swap) - { - LLM temp = array[j]; - array[j] = array[j + 1]; - array[j + 1] = temp; - swapped = true; - } - } - if (!swapped) break; - } - } - protected void InitHistory() { InitPrompt(); @@ -416,67 +327,6 @@ public void AddAIMessage(string content) AddMessage(AIName, content); } - protected string ChatContent(ChatResult result) - { - // get content from a chat result received from the endpoint - return result.content.Trim(); - } - - protected string MultiChatContent(MultiChatResult result) - { - // get content from a chat result received from the endpoint - string response = ""; - foreach (ChatResult resultPart in result.data) - { - response += resultPart.content; - } - return response.Trim(); - } - - async Task CompletionRequest(string json, Callback callback = null) - { - string result = ""; - if (stream) - { - result = await PostRequest(json, "completion", MultiChatContent, callback); - } - else - { - result = await PostRequest(json, "completion", ChatContent, callback); - } - return result; - } - - protected string TemplateContent(TemplateResult result) - { - // get content from a char result received from the endpoint in open AI format - return result.template; - } - - protected List TokenizeContent(TokenizeResult result) - { - // get the tokens from a tokenize result received from the 
endpoint - return result.tokens; - } - - protected string DetokenizeContent(TokenizeRequest result) - { - // get content from a chat result received from the endpoint - return result.content; - } - - protected List EmbeddingsContent(EmbeddingsResult result) - { - // get content from a chat result received from the endpoint - return result.embedding; - } - - protected string SlotContent(SlotResult result) - { - // get the tokens from a tokenize result received from the endpoint - return result.filename; - } - /// /// Chat functionality of the LLM. /// It calls the LLM completion based on the provided query including the previous chat history. @@ -578,51 +428,6 @@ public async Task AskTemplate() return await PostRequest("{}", "template", TemplateContent); } - /// - /// Tokenises the provided query. - /// - /// query to tokenise - /// callback function called with the result tokens - /// list of the tokens - public async Task> Tokenize(string query, Callback> callback = null) - { - // handle the tokenization of a message by the user - TokenizeRequest tokenizeRequest = new TokenizeRequest(); - tokenizeRequest.content = query; - string json = JsonUtility.ToJson(tokenizeRequest); - return await PostRequest>(json, "tokenize", TokenizeContent, callback); - } - - /// - /// Detokenises the provided tokens to a string. - /// - /// tokens to detokenise - /// callback function called with the result string - /// the detokenised string - public async Task Detokenize(List tokens, Callback callback = null) - { - // handle the detokenization of a message by the user - TokenizeResult tokenizeRequest = new TokenizeResult(); - tokenizeRequest.tokens = tokens; - string json = JsonUtility.ToJson(tokenizeRequest); - return await PostRequest(json, "detokenize", DetokenizeContent, callback); - } - - /// - /// Computes the embeddings of the provided input. - /// - /// input to compute the embeddings for - /// callback function called with the result string - /// the computed embeddings - public async Task> Embeddings(string query, Callback> callback = null) - { - // handle the tokenization of a message by the user - TokenizeRequest tokenizeRequest = new TokenizeRequest(); - tokenizeRequest.content = query; - string json = JsonUtility.ToJson(tokenizeRequest); - return await PostRequest>(json, "embeddings", EmbeddingsContent, callback); - } - protected async Task Slot(string filepath, string action) { SlotRequest slotRequest = new SlotRequest(); @@ -676,173 +481,6 @@ public virtual async Task Load(string filename) string result = await Slot(cachepath, "restore"); return result; } - - protected Ret ConvertContent(string response, ContentCallback getContent = null) - { - // template function to convert the json received and get the content - if (response == null) return default; - response = response.Trim(); - if (response.StartsWith("data: ")) - { - string responseArray = ""; - foreach (string responsePart in response.Replace("\n\n", "").Split("data: ")) - { - if (responsePart == "") continue; - if (responseArray != "") responseArray += ",\n"; - responseArray += responsePart; - } - response = $"{{\"data\": [{responseArray}]}}"; - } - return getContent(JsonUtility.FromJson(response)); - } - - protected void CancelRequestsLocal() - { - if (slot >= 0) llm.CancelRequest(slot); - } - - protected void CancelRequestsRemote() - { - foreach (UnityWebRequest request in WIPRequests) - { - request.Abort(); - } - WIPRequests.Clear(); - } - - /// - /// Cancel the ongoing requests e.g. Chat, Complete. 
- /// - // - public void CancelRequests() - { - if (remote) CancelRequestsRemote(); - else CancelRequestsLocal(); - } - - protected async Task PostRequestLocal(string json, string endpoint, ContentCallback getContent, Callback callback = null) - { - // send a post request to the server and call the relevant callbacks to convert the received content and handle it - // this function has streaming functionality i.e. handles the answer while it is being received - string callResult = null; - bool callbackCalled = false; - while (!llm.failed && !llm.started) await Task.Yield(); - switch (endpoint) - { - case "tokenize": - callResult = await llm.Tokenize(json); - break; - case "detokenize": - callResult = await llm.Detokenize(json); - break; - case "embeddings": - callResult = await llm.Embeddings(json); - break; - case "slots": - callResult = await llm.Slot(json); - break; - case "completion": - Callback callbackString = null; - if (stream && callback != null) - { - if (typeof(Ret) == typeof(string)) - { - callbackString = (strArg) => - { - callback(ConvertContent(strArg, getContent)); - }; - } - else - { - LLMUnitySetup.LogError($"wrong callback type, should be string"); - } - callbackCalled = true; - } - callResult = await llm.Completion(json, callbackString); - break; - default: - LLMUnitySetup.LogError($"Unknown endpoint {endpoint}"); - break; - } - - Ret result = ConvertContent(callResult, getContent); - if (!callbackCalled) callback?.Invoke(result); - return result; - } - - protected async Task PostRequestRemote(string json, string endpoint, ContentCallback getContent, Callback callback = null) - { - // send a post request to the server and call the relevant callbacks to convert the received content and handle it - // this function has streaming functionality i.e. 
handles the answer while it is being received - if (endpoint == "slots") - { - LLMUnitySetup.LogError("Saving and loading is not currently supported in remote setting"); - return default; - } - - Ret result = default; - byte[] jsonToSend = new System.Text.UTF8Encoding().GetBytes(json); - UnityWebRequest request = null; - string error = null; - int tryNr = numRetries; - - while (tryNr != 0) - { - using (request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend)) - { - WIPRequests.Add(request); - - request.method = "POST"; - if (requestHeaders != null) - { - for (int i = 0; i < requestHeaders.Count; i++) - request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2); - } - - // Start the request asynchronously - var asyncOperation = request.SendWebRequest(); - float lastProgress = 0f; - // Continue updating progress until the request is completed - while (!asyncOperation.isDone) - { - float currentProgress = request.downloadProgress; - // Check if progress has changed - if (currentProgress != lastProgress && callback != null) - { - callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent)); - lastProgress = currentProgress; - } - // Wait for the next frame - await Task.Yield(); - } - WIPRequests.Remove(request); - if (request.result == UnityWebRequest.Result.Success) - { - result = ConvertContent(request.downloadHandler.text, getContent); - error = null; - break; - } - else - { - result = default; - error = request.error; - if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break; - } - } - tryNr--; - if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr)); - } - - if (error != null) LLMUnitySetup.LogError(error); - callback?.Invoke(result); - return result; - } - - protected async Task PostRequest(string json, string endpoint, ContentCallback getContent, Callback callback = null) - { - if (remote) return await PostRequestRemote(json, endpoint, getContent, callback); - return await PostRequestLocal(json, endpoint, getContent, callback); - } } /// \cond HIDE diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index ba97ca3b..d4163ff0 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -16,9 +16,13 @@ public class ModelEntry public bool lora; public string chatTemplate; public string url; + public bool embeddingOnly; + public int embeddingLength; public bool includeInBuild; public int contextLength; + static List embeddingOnlyArchs = new List {"bert", "nomic-bert", "jina-bert-v2", "t5", "t5encoder"}; + public static string GetFilenameOrRelativeAssetPath(string path) { string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed @@ -40,12 +44,16 @@ public ModelEntry(string path, bool lora = false, string label = null, string ur includeInBuild = true; chatTemplate = null; contextLength = -1; + embeddingOnly = false; + embeddingLength = 0; if (!lora) { GGUFReader reader = new GGUFReader(this.path); chatTemplate = ChatTemplate.FromGGUF(reader, this.path); string arch = reader.GetStringField("general.architecture"); if (arch != null) contextLength = reader.GetIntField($"{arch}.context_length"); + if (arch != null) embeddingLength = reader.GetIntField($"{arch}.embedding_length"); + embeddingOnly = embeddingOnlyArchs.Contains(arch); } } diff --git a/Runtime/LLMUnitySetup.cs b/Runtime/LLMUnitySetup.cs index fe199c3c..54a2d587 100644 --- a/Runtime/LLMUnitySetup.cs +++ b/Runtime/LLMUnitySetup.cs @@ -106,12 +106,18 @@ public class LLMUnitySetup /// Default 
models for download [HideInInspector] public static readonly (string, string, string)[] modelOptions = new(string, string, string)[] { - ("Llama 3.1 8B (medium, best overall)", "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf?download=true", "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE"), - ("Gemma 2 9B it (medium, great overall)", "https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf?download=true", "https://ai.google.dev/gemma/terms"), - ("Mistral 7B Instruct v0.2 (medium, great overall)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true", null), - ("OpenHermes 2.5 7B (medium, good for conversation)", "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf?download=true", null), - ("Phi 3 (small, great small model)", "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf?download=true", null), - ("Qwen 2 0.5B (tiny, useful for mobile)", "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true", null), + // completion models + ("πŸ’¬ Llama 3.1 8B (large, best overall)", "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", "https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/LICENSE"), + ("πŸ’¬ Gemma 2 9B it (large, great overall)", "https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf", "https://ai.google.dev/gemma/terms"), + ("πŸ’¬ Mistral 7B Instruct v0.2 (large, great overall)", "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf", null), + ("πŸ’¬ OpenHermes 2.5 7B (large, good for conversation)", "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf", null), + ("πŸ’¬ Phi 3 (medium, great small model)", "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf", null), + ("πŸ’¬ Qwen 2 0.5B (tiny, useful for mobile)", "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf", null), + // embedding models + ("πŸ” BGE large en v1.5 (large)", "https://huggingface.co/CompendiumLabs/bge-large-en-v1.5-gguf/resolve/main/bge-large-en-v1.5-q4_k_m.gguf", null), + ("πŸ” BGE base en v1.5 (medium)", "https://huggingface.co/CompendiumLabs/bge-base-en-v1.5-gguf/resolve/main/bge-base-en-v1.5-q4_k_m.gguf", null), + ("πŸ” BGE small en v1.5 (small)", "https://huggingface.co/CompendiumLabs/bge-small-en-v1.5-gguf/resolve/main/bge-small-en-v1.5-q4_k_m.gguf", null), + ("πŸ” All MiniLM L12 v2 (small, benchmark)", "https://huggingface.co/leliuga/all-MiniLM-L12-v2-GGUF/resolve/main/all-MiniLM-L12-v2.Q4_K_M.gguf", null), }; /// \cond HIDE diff --git a/Runtime/Search.cs b/Runtime/Search.cs new file mode 100644 index 00000000..7821a5ea --- /dev/null +++ b/Runtime/Search.cs @@ -0,0 +1,224 @@ +using System.Collections.Generic; +using System.IO; +using System.IO.Compression; +using System.Runtime.Serialization.Formatters.Binary; +using System.Threading.Tasks; +using UnityEngine; + +namespace LLMUnity +{ + public class ArchiveSaver + { + public delegate void ArchiveSaverCallback(ZipArchive archive); + + public static void 
Save(string filePath, ArchiveSaverCallback callback) + { + using (FileStream stream = new FileStream(filePath, FileMode.Create)) + using (ZipArchive archive = new ZipArchive(stream, ZipArchiveMode.Create)) + { + callback(archive); + } + } + + public static void Load(string filePath, ArchiveSaverCallback callback) + { + using (FileStream stream = new FileStream(filePath, FileMode.Open)) + using (ZipArchive archive = new ZipArchive(stream, ZipArchiveMode.Read)) + { + callback(archive); + } + } + + public static void Save(ZipArchive archive, object saveObject, string name) + { + ZipArchiveEntry mainEntry = archive.CreateEntry(name); + using (Stream entryStream = mainEntry.Open()) + { + BinaryFormatter formatter = new BinaryFormatter(); + formatter.Serialize(entryStream, saveObject); + } + } + + public static T Load(ZipArchive archive, string name) + { + ZipArchiveEntry baseEntry = archive.GetEntry(name); + using (Stream entryStream = baseEntry.Open()) + { + BinaryFormatter formatter = new BinaryFormatter(); + T obj = (T)formatter.Deserialize(entryStream); + return obj; + } + } + } + + public interface ISearchable + { + public string Get(int key); + public Task Add(string inputString); + public int Remove(string inputString); + public void Remove(int key); + public int Count(); + public void Clear(); + public Task Search(string queryString, int k); + public void Save(string filePath); + public void Save(ZipArchive archive); + public void Load(string filePath); + public void Load(ZipArchive archive); + } + + [DefaultExecutionOrder(-2)] + public abstract class SearchMethod : LLMCaller, ISearchable + { + [HideInInspector, SerializeField] protected int nextKey = 0; + [HideInInspector, SerializeField] protected int nextIncrementalSearchKey = 0; + + protected SortedDictionary data = new SortedDictionary(); + + public abstract int IncrementalSearch(float[] embedding); + public abstract (int[], bool) IncrementalFetchKeys(int fetchKey, int k); + public abstract void IncrementalSearchComplete(int fetchKey); + protected abstract int[] SearchInternal(float[] encoding, int k, out float[] distances); + protected abstract void AddInternal(int key, float[] embedding); + protected abstract void RemoveInternal(int key); + protected abstract void ClearInternal(); + protected abstract void SaveInternal(ZipArchive archive); + protected abstract void LoadInternal(ZipArchive archive); + + public virtual async Task Encode(string inputString) + { + return (await Embeddings(inputString.Trim())).ToArray(); + } + + public virtual string Get(int key) + { + return data[key]; + } + + public virtual async Task Add(string inputString) + { + int key = nextKey++; + AddInternal(key, await Encode(inputString)); + data[key] = inputString; + return key; + } + + public virtual void Remove(int key) + { + data.Remove(key); + RemoveInternal(key); + } + + public virtual void Clear() + { + data.Clear(); + ClearInternal(); + nextKey = 0; + nextIncrementalSearchKey = 0; + } + + public virtual int Remove(string inputString) + { + List removeIds = new List(); + foreach (var entry in data) + { + if (entry.Value == inputString) removeIds.Add(entry.Key); + } + foreach (int id in removeIds) Remove(id); + return removeIds.Count; + } + + public virtual int Count() + { + return data.Count; + } + + public virtual string[] Search(float[] encoding, int k) + { + int[] keys = SearchInternal(encoding, k, out float[] distances); + string[] result = new string[keys.Length]; + for (int i = 0; i < keys.Length; i++) result[i] = Get(keys[i]); + return result; 
+ } + + public virtual async Task Search(string queryString, int k) + { + return Search(await Encode(queryString), k); + } + + public virtual async Task IncrementalSearch(string queryString) + { + return IncrementalSearch(await Encode(queryString)); + } + + public virtual (string[], bool) IncrementalFetch(int fetchKey, int k) + { + int[] resultKeys; + bool completed; + (resultKeys, completed) = IncrementalFetchKeys(fetchKey, k); + string[] results = new string[resultKeys.Length]; + for (int i = 0; i < resultKeys.Length; i++) results[i] = Get(resultKeys[i]); + return (results, completed); + } + + public virtual void Save(string filePath) + { + ArchiveSaver.Save(filePath, Save); + } + + public virtual void Load(string filePath) + { + ArchiveSaver.Load(filePath, Load); + } + + public virtual void Save(ZipArchive archive) + { + ArchiveSaver.Save(archive, JsonUtility.ToJson(this), "Search_object"); + ArchiveSaver.Save(archive, data, "Search_data"); + SaveInternal(archive); + } + + public virtual void Load(ZipArchive archive) + { + JsonUtility.FromJsonOverwrite(ArchiveSaver.Load(archive, "Search_object"), this); + data = ArchiveSaver.Load>(archive, "Search_data"); + LoadInternal(archive); + } + } + + public abstract class SearchPlugin : MonoBehaviour, ISearchable + { + public SearchMethod search; + + public abstract string Get(int key); + public abstract Task Add(string inputString); + public abstract int Remove(string inputString); + public abstract void Remove(int key); + public abstract int Count(); + public abstract void Clear(); + public abstract Task Search(string queryString, int k); + protected abstract void SaveInternal(ZipArchive archive); + protected abstract void LoadInternal(ZipArchive archive); + + public virtual void Save(string filePath) + { + ArchiveSaver.Save(filePath, Save); + } + + public virtual void Load(string filePath) + { + ArchiveSaver.Load(filePath, Load); + } + + public virtual void Save(ZipArchive archive) + { + ArchiveSaver.Save(archive, JsonUtility.ToJson(this, true), "SearchPlugin_object"); + SaveInternal(archive); + } + + public virtual void Load(ZipArchive archive) + { + JsonUtility.FromJsonOverwrite(ArchiveSaver.Load(archive, "SearchPlugin_object"), this); + LoadInternal(archive); + } + } +} diff --git a/Runtime/Search.cs.meta b/Runtime/Search.cs.meta new file mode 100644 index 00000000..4d857653 --- /dev/null +++ b/Runtime/Search.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 98f7725a31f2e7d8485b2cdf541fc8d4 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/SentenceSplitter.cs b/Runtime/SentenceSplitter.cs new file mode 100644 index 00000000..2d39287f --- /dev/null +++ b/Runtime/SentenceSplitter.cs @@ -0,0 +1,139 @@ +using System; +using System.Collections.Generic; +using System.IO.Compression; +using System.Linq; +using System.Threading.Tasks; +using UnityEngine; + +namespace LLMUnity +{ + [Serializable] + public class SentenceSplitter : SearchPlugin + { + public const string DefaultDelimiters = ".!:;?\n\r"; + public string delimiters = DefaultDelimiters; + public bool returnChunks = false; + public Dictionary phraseToSentences = new Dictionary(); + public Dictionary sentenceToPhrase = new Dictionary(); + [HideInInspector, SerializeField] protected int nextKey = 0; + + public List<(int, int)> Split(string input) + { + List<(int, int)> indices = new List<(int, int)>(); + int startIndex = 0; + for (int i 
= 0; i < input.Length; i++) + { + if (delimiters.Contains(input[i]) || i == input.Length - 1) + { + if (i > startIndex) indices.Add((startIndex, i)); + startIndex = i + 1; + } + } + return indices; + } + + public override string Get(int key) + { + string phrase = ""; + foreach (int sentenceId in phraseToSentences[key]) phrase += search.Get(sentenceId); + return phrase; + } + + public override async Task Add(string inputString) + { + int key = nextKey++; + List sentenceIds = new List(); + foreach ((int startIndex, int endIndex) in Split(inputString).ToArray()) + { + string sentenceText = inputString.Substring(startIndex, endIndex - startIndex + 1); + sentenceIds.Add(await search.Add(sentenceText)); + } + phraseToSentences[key] = sentenceIds.ToArray(); + return key; + } + + public override void Remove(int key) + { + phraseToSentences.TryGetValue(key, out int[] sentenceIds); + if (sentenceIds == null) return; + phraseToSentences.Remove(key); + foreach (int sentenceId in sentenceIds) search.Remove(sentenceId); + } + + public override int Remove(string inputString) + { + List removeIds = new List(); + foreach (var entry in phraseToSentences) + { + string phrase = ""; + foreach (int sentenceId in entry.Value) + { + phrase += search.Get(sentenceId); + if (phrase.Length > inputString.Length) break; + } + if (phrase == inputString) removeIds.Add(entry.Key); + } + foreach (int id in removeIds) Remove(id); + return removeIds.Count; + } + + public override int Count() + { + return phraseToSentences.Count; + } + + public override async Task Search(string queryString, int k) + { + if (returnChunks) + { + return await search.Search(queryString, k); + } + else + { + int searchKey = await search.IncrementalSearch(queryString); + List phraseKeys = new List(); + List phrases = new List(); + bool complete; + do + { + int[] resultKeys; + (resultKeys, complete) = search.IncrementalFetchKeys(searchKey, k); + for (int i = 0; i < resultKeys.Length; i++) + { + int phraseId = sentenceToPhrase[resultKeys[i]]; + if (phraseKeys.Contains(phraseId)) continue; + phraseKeys.Add(phraseId); + phrases.Add(Get(phraseId)); + if (phraseKeys.Count() == k) + { + complete = true; + break; + } + } + } + while (!complete); + return phrases.ToArray(); + } + } + + public override void Clear() + { + nextKey = 0; + phraseToSentences.Clear(); + sentenceToPhrase.Clear(); + search.Clear(); + } + + protected override void SaveInternal(ZipArchive archive) + { + ArchiveSaver.Save(archive, phraseToSentences, "SentenceSplitter_phraseToSentences"); + ArchiveSaver.Save(archive, sentenceToPhrase, "SentenceSplitter_sentenceToPhrase"); + } + + protected override void LoadInternal(ZipArchive archive) + { + phraseToSentences = ArchiveSaver.Load>(archive, "SentenceSplitter_phraseToSentences"); + sentenceToPhrase = ArchiveSaver.Load>(archive, "SentenceSplitter_sentenceToPhrase"); + } + } +} diff --git a/Runtime/SentenceSplitter.cs.meta b/Runtime/SentenceSplitter.cs.meta new file mode 100644 index 00000000..6102827f --- /dev/null +++ b/Runtime/SentenceSplitter.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 354c2418b0d6913efbfb73dccd540d23 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/SimpleSearch.cs b/Runtime/SimpleSearch.cs new file mode 100644 index 00000000..b0073649 --- /dev/null +++ b/Runtime/SimpleSearch.cs @@ -0,0 +1,128 @@ +using System; +using System.Collections.Generic; +using 
System.Linq; +using System.IO.Compression; +using UnityEngine; + +namespace LLMUnity +{ + [DefaultExecutionOrder(-2)] + public class SimpleSearch : SearchMethod + { + protected SortedDictionary embeddings = new SortedDictionary(); + protected Dictionary> incrementalSearchCache = new Dictionary>(); + + protected override void AddInternal(int key, float[] embedding) + { + embeddings[key] = embedding; + } + + protected override void RemoveInternal(int key) + { + embeddings.Remove(key); + } + + public static float DotProduct(float[] vector1, float[] vector2) + { + if (vector1.Length != vector2.Length) + { + throw new ArgumentException("Vector lengths must be equal for dot product calculation"); + } + float result = 0; + for (int i = 0; i < vector1.Length; i++) + { + result += vector1[i] * vector2[i]; + } + return result; + } + + public static float InverseDotProduct(float[] vector1, float[] vector2) + { + return 1 - DotProduct(vector1, vector2); + } + + public static float[] InverseDotProduct(float[] vector1, float[][] vector2) + { + float[] results = new float[vector2.Length]; + for (int i = 0; i < vector2.Length; i++) + { + results[i] = InverseDotProduct(vector1, vector2[i]); + } + return results; + } + + protected override int[] SearchInternal(float[] embedding, int k, out float[] distances) + { + float[] unsortedDistances = InverseDotProduct(embedding, embeddings.Values.ToArray()); + var sortedLists = embeddings.Keys.Zip(unsortedDistances, (first, second) => new { First = first, Second = second }) + .OrderBy(item => item.Second) + .ToList(); + int kmax = k == -1 ? sortedLists.Count : Math.Min(k, sortedLists.Count); + int[] results = new int[kmax]; + distances = new float[kmax]; + for (int i = 0; i < kmax; i++) + { + results[i] = sortedLists[i].First; + distances[i] = sortedLists[i].Second; + } + return results; + } + + public override int IncrementalSearch(float[] embedding) + { + int key = nextIncrementalSearchKey++; + float[] unsortedDistances = InverseDotProduct(embedding, embeddings.Values.ToArray()); + incrementalSearchCache[key] = embeddings.Keys.Zip(unsortedDistances, (first, second) => (first, second)) + .OrderBy(item => item.second) + .ToList(); + return key; + } + + public override (int[], bool) IncrementalFetchKeys(int fetchKey, int k) + { + if (!incrementalSearchCache.ContainsKey(fetchKey)) throw new Exception($"There is no IncrementalSearch cached with this key: {fetchKey}"); + + bool completed; + List<(int, float)> sortedLists; + if (k == -1) + { + sortedLists = incrementalSearchCache[fetchKey]; + completed = true; + } + else + { + sortedLists = incrementalSearchCache[fetchKey].GetRange(0, k); + incrementalSearchCache[fetchKey].RemoveRange(0, k); + completed = incrementalSearchCache[fetchKey].Count == 0; + } + if (completed) IncrementalSearchComplete(fetchKey); + + List results = new List(); + foreach ((int key, float distance) in sortedLists) results.Add(key); + return (results.ToArray(), completed); + } + + public override void IncrementalSearchComplete(int fetchKey) + { + incrementalSearchCache.Remove(fetchKey); + } + + protected override void ClearInternal() + { + embeddings.Clear(); + incrementalSearchCache.Clear(); + } + + protected override void SaveInternal(ZipArchive archive) + { + ArchiveSaver.Save(archive, embeddings, "SimpleSearch_embeddings"); + ArchiveSaver.Save(archive, incrementalSearchCache, "SimpleSearch_incrementalSearchCache"); + } + + protected override void LoadInternal(ZipArchive archive) + { + embeddings = ArchiveSaver.Load>(archive, 
"SimpleSearch_embeddings"); + incrementalSearchCache = ArchiveSaver.Load>>(archive, "SimpleSearch_incrementalSearchCache"); + } + } +} diff --git a/Runtime/SimpleSearch.cs.meta b/Runtime/SimpleSearch.cs.meta new file mode 100644 index 00000000..21d7f0ff --- /dev/null +++ b/Runtime/SimpleSearch.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 5102d32385d84d87f98c64d376cbcc90 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/undream.llmunity.Runtime.asmdef b/Runtime/undream.llmunity.Runtime.asmdef index 140aec56..800db04c 100644 --- a/Runtime/undream.llmunity.Runtime.asmdef +++ b/Runtime/undream.llmunity.Runtime.asmdef @@ -1,3 +1,6 @@ { - "name": "undream.llmunity.Runtime" + "name": "undream.llmunity.Runtime", + "references": [ + "Cloud.Unum.USearch" + ] } \ No newline at end of file diff --git a/Samples~/KnowledgeBaseGame/KnowledgeBase.asmdef b/Samples~/KnowledgeBaseGame/KnowledgeBase.asmdef deleted file mode 100644 index e8831c36..00000000 --- a/Samples~/KnowledgeBaseGame/KnowledgeBase.asmdef +++ /dev/null @@ -1,28 +0,0 @@ -{ - "name": "KnowledgeBase", - "rootNamespace": "", - "references": [ - "undream.llmunity.Runtime", - "undream.RAGSearchUnity.Runtime", - "Unity.Sentis", - "HuggingFace.SharpTransformers", - "Cloud.Unum.USearch" - ], - "includePlatforms": [], - "excludePlatforms": [], - "allowUnsafeCode": false, - "overrideReferences": false, - "precompiledReferences": [], - "autoReferenced": true, - "defineConstraints": [ - "RAGSEARCHUNITY" - ], - "versionDefines": [ - { - "name": "ai.undream.ragsearchunity", - "expression": "1.0.0", - "define": "RAGSEARCHUNITY" - } - ], - "noEngineReferences": false -} \ No newline at end of file diff --git a/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs b/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs index e6f3bc29..a9649089 100644 --- a/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs +++ b/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs @@ -5,8 +5,8 @@ using Debug = UnityEngine.Debug; using UnityEngine.UI; using System.Collections; -using RAGSearchUnity; using LLMUnity; +using System.Threading.Tasks; namespace LLMUnitySamples { @@ -15,10 +15,10 @@ public class Bot Dictionary dialogues; SearchEngine search; - public Bot(string dialogueText, EmbeddingModel model, string embeddingsPath) + public Bot(string dialogueText, LLM llm, string embeddingsPath) { LoadDialogues(dialogueText); - CreateEmbeddings(model, embeddingsPath); + CreateEmbeddings(llm, embeddingsPath); } void LoadDialogues(string dialogueText) @@ -32,17 +32,17 @@ void LoadDialogues(string dialogueText) } } - void CreateEmbeddings(EmbeddingModel model, string embeddingsPath) + void CreateEmbeddings(LLM llm, string embeddingsPath) { if (File.Exists(embeddingsPath)) { // load the embeddings - search = SearchEngine.Load(model, embeddingsPath); + search = SearchEngine.Load(llm, embeddingsPath); } else { #if UNITY_EDITOR - search = new SearchEngine(model); + search = new SearchEngine(llm); Stopwatch stopwatch = new Stopwatch(); stopwatch.Start(); // build the embeddings @@ -60,9 +60,9 @@ void CreateEmbeddings(EmbeddingModel model, string embeddingsPath) } } - public List Retrieval(string question) + public async Task> Retrieval(string question) { - string[] similarQuestions = search.Search(question, 3); + string[] similarQuestions = await search.Search(question, 3); List similarAnswers = new List(); foreach (string similarQuestion in 
similarQuestions) similarAnswers.Add(dialogues[similarQuestion]); return similarAnswers; @@ -77,7 +77,7 @@ public int NumPhrases() public class KnowledgeBaseGame : KnowledgeBaseGameUI { public LLMCharacter llmCharacter; - public Embedding embedding; + public LLM llmEmbedding; Dictionary bots = new Dictionary(); Dictionary botImages = new Dictionary(); @@ -106,14 +106,13 @@ IEnumerator InitModels() if (!Directory.Exists(outputDir)) Directory.CreateDirectory(outputDir); // init the bots with the embeddings - EmbeddingModel model = embedding.GetModel(); - if (model == null) throw new System.Exception("Please select a model in the Embedding GameObject!"); + if (llmEmbedding == null) throw new System.Exception("Please select a model in the Embedding GameObject!"); foreach ((string botName, TextAsset asset, RawImage image) in botInfo) { string embeddingsPath = Path.Combine(outputDir, botName + ".zip"); PlayerText.text += File.Exists(embeddingsPath) ? $"Loading {botName} dialogues...\n" : $"Creating Embeddings for {botName} (only once)...\n"; yield return null; - bots[botName] = new Bot(asset.text, model, embeddingsPath); + bots[botName] = new Bot(asset.text, llmEmbedding, embeddingsPath); botImages[botName] = image; } @@ -125,10 +124,10 @@ IEnumerator InitModels() yield return null; } - protected override void OnInputFieldSubmit(string question) + protected async override void OnInputFieldSubmit(string question) { PlayerText.interactable = false; - List similarAnswers = currentBot.Retrieval(question); + List similarAnswers = await currentBot.Retrieval(question); string prompt = $"Question:\n{question}\n\nAnswers:\n"; foreach (string similarAnswer in similarAnswers) prompt += $"- {similarAnswer}\n"; _ = llmCharacter.Chat(prompt, SetAIText, AIReplyComplete); @@ -177,7 +176,7 @@ void OnValidate() { if (onValidateWarning) { - if (embedding.SelectedOption == 0) Debug.LogWarning($"Please select a model in the {embedding.gameObject.name} GameObject!"); + if (llmEmbedding == null) Debug.LogWarning($"Please select a llmEmbedding model in the {gameObject.name} GameObject!"); if (!llmCharacter.remote && llmCharacter.llm != null && llmCharacter.llm.model == "") Debug.LogWarning($"Please select a model in the {llmCharacter.llm.gameObject.name} GameObject!"); onValidateWarning = false; } @@ -188,7 +187,6 @@ public class KnowledgeBaseGameUI : MonoBehaviour { public Dropdown CharacterSelect; public InputField PlayerText; - public Text SetupText; public Text AIText; public TextAsset ButlerText; @@ -216,11 +214,6 @@ public class KnowledgeBaseGameUI : MonoBehaviour public Dropdown Answer2; public Dropdown Answer3; - void Awake() - { - if (SetupText != null) SetupText.gameObject.SetActive(false); - } - protected void Start() { AddListeners(); diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index 747e8af6..f6346435 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -71,7 +71,6 @@ public void TestLoras() public class TestLLM { - protected static string modelUrl = "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true"; protected string modelNameLLManager; protected GameObject gameObject; @@ -120,9 +119,14 @@ public virtual void SetParameters() tokens2 = 9; } + protected virtual string GetModelUrl() + { + return "https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/resolve/main/qwen2-0_5b-instruct-q4_k_m.gguf?download=true"; + } + public virtual async Task DownloadModels() { - modelNameLLManager = await 
LLMManager.DownloadModel(modelUrl); + modelNameLLManager = await LLMManager.DownloadModel(GetModelUrl()); } [Test] @@ -292,7 +296,7 @@ public class TestLLM_LLMManager_Load : TestLLM public override LLM CreateLLM() { LLM llm = gameObject.AddComponent(); - string filename = Path.GetFileName(modelUrl).Split("?")[0]; + string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); filename = LLMManager.LoadModel(sourcePath); llm.SetModel(filename); @@ -308,7 +312,7 @@ public class TestLLM_StreamingAssets_Load : TestLLM public override LLM CreateLLM() { LLM llm = gameObject.AddComponent(); - string filename = Path.GetFileName(modelUrl).Split("?")[0]; + string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; string sourcePath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); loadPath = LLMUnitySetup.GetAssetPath(filename); if (!File.Exists(loadPath)) File.Copy(sourcePath, loadPath); @@ -328,7 +332,7 @@ public class TestLLM_SetModel_Warning : TestLLM public override LLM CreateLLM() { LLM llm = gameObject.AddComponent(); - string filename = Path.GetFileName(modelUrl).Split("?")[0]; + string filename = Path.GetFileName(GetModelUrl()).Split("?")[0]; string loadPath = Path.Combine(LLMUnitySetup.modelDownloadPath, filename); llm.SetModel(loadPath); llm.parallelPrompts = 1; @@ -400,7 +404,7 @@ public override async Task Tests() public void TestModelPaths() { - Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(modelUrl).Split("?")[0]).Replace('\\', '/')); + Assert.AreEqual(llm.model, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(GetModelUrl()).Split("?")[0]).Replace('\\', '/')); Assert.AreEqual(llm.lora, Path.Combine(LLMUnitySetup.modelDownloadPath, Path.GetFileName(loraUrl).Split("?")[0]).Replace('\\', '/')); } diff --git a/Tests/Runtime/TestSearch.cs b/Tests/Runtime/TestSearch.cs new file mode 100644 index 00000000..6aca5e06 --- /dev/null +++ b/Tests/Runtime/TestSearch.cs @@ -0,0 +1,215 @@ +using NUnit.Framework; +using System.IO; +using System.Threading.Tasks; +using UnityEngine; +using LLMUnity; +using System; +using UnityEngine.TestTools; +using System.Collections; + +namespace LLMUnityTests +{ + public class TestSimpleSearch + { + string weather = "how is the weather today?"; + string raining = "is it raining?"; + string random = "something completely random"; + + protected string modelNameLLManager; + + protected GameObject gameObject; + protected LLM llm; + public SearchMethod search; + protected Exception error = null; + + public TestSimpleSearch() + { + Task task = Init(); + task.Wait(); + } + + public virtual async Task Init() + { + await DownloadModels(); + gameObject = new GameObject(); + gameObject.SetActive(false); + llm = CreateLLM(); + search = CreateSearch(); + gameObject.SetActive(true); + } + + public virtual LLM CreateLLM() + { + LLM llm = gameObject.AddComponent(); + llm.SetModel(modelNameLLManager); + llm.parallelPrompts = 1; + return llm; + } + + public virtual SearchMethod CreateSearch() + { + SimpleSearch search = gameObject.AddComponent(); + search.llm = llm; + search.stream = false; + return search; + } + + public virtual async Task DownloadModels() + { + modelNameLLManager = await LLMManager.DownloadModel(GetModelUrl()); + } + + protected string GetModelUrl() + { + return "https://huggingface.co/CompendiumLabs/bge-small-en-v1.5-gguf/resolve/main/bge-small-en-v1.5-f16.gguf"; + } + + public static bool ApproxEqual(float x1, 
float x2, float tolerance = 0.0001f) + { + return Mathf.Abs(x1 - x2) < tolerance; + } + + [UnityTest] + public IEnumerator RunTests() + { + Task task = RunTestsTask(); + while (!task.IsCompleted) yield return null; + if (error != null) + { + Debug.LogError(error.ToString()); + throw (error); + } + OnDestroy(); + } + + public async Task RunTestsTask() + { + error = null; + try + { + await Tests(); + llm.OnDestroy(); + } + catch (Exception e) + { + error = e; + } + } + + public virtual void OnDestroy() {} + + + public async Task Tests() + { + await TestEncode(); + await TestSimilarity(); + await TestAdd(); + await TestSearch(); + await TestIncrementalSearch(); + } + + public async Task TestEncode() + { + float[] encoding = await search.Encode(weather); + Assert.That(ApproxEqual(encoding[0], -0.02910374f)); + Assert.That(ApproxEqual(encoding[383], 0.01764517f)); + } + + public async Task TestSimilarity() + { + float[] sentence1 = await search.Encode(weather); + float[] sentence2 = await search.Encode(raining); + float trueSimilarity = 0.7926437f; + float similarity = SimpleSearch.DotProduct(sentence1, sentence2); + float distance = SimpleSearch.InverseDotProduct(sentence1, sentence2); + Assert.That(ApproxEqual(similarity, trueSimilarity)); + Assert.That(ApproxEqual(distance, 1 - trueSimilarity)); + } + + public async Task TestAdd() + { + int key = await search.Add(weather); + Assert.That(search.Get(key) == weather); + Assert.That(search.Count() == 1); + search.Remove(key); + Assert.That(search.Count() == 0); + + await search.Add(weather); + await search.Add(raining); + await search.Add(random); + Assert.That(search.Count() == 3); + search.Clear(); + Assert.That(search.Count() == 0); + } + + public async Task TestSearch() + { + await search.Add(weather); + await search.Add(raining); + await search.Add(random); + + string[] result = await search.Search(weather, 2); + Assert.AreEqual(result[0], weather); + Assert.AreEqual(result[1], raining); + + float[] encoding = await search.Encode(weather); + result = search.Search(encoding, 2); + Assert.AreEqual(result[0], weather); + Assert.AreEqual(result[1], raining); + + search.Clear(); + } + + public async Task TestIncrementalSearch() + { + await search.Add(weather); + await search.Add(raining); + await search.Add(random); + + int searchKey = await search.IncrementalSearch(weather); + string[] results; + bool completed; + (results, completed) = search.IncrementalFetch(searchKey, 1); + Assert.That(results.Length == 1); + Assert.AreEqual(results[0], weather); + Assert.That(!completed); + + (results, completed) = search.IncrementalFetch(searchKey, 2); + Assert.That(results.Length == 2); + Assert.AreEqual(results[0], raining); + Assert.AreEqual(results[1], random); + Assert.That(completed); + + searchKey = await search.IncrementalSearch(weather); + (results, completed) = search.IncrementalFetch(searchKey, 2); + Assert.That(results.Length == 2); + Assert.AreEqual(results[0], weather); + Assert.AreEqual(results[1], raining); + Assert.That(!completed); + + search.IncrementalSearchComplete(searchKey); + search.Clear(); + } + + public async Task TestSave() + { + string path = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + + await search.Add(weather); + await search.Add(raining); + await search.Add(random); + search.Save(path); + } + } + + public class TestDBSearch : TestSimpleSearch + { + public override SearchMethod CreateSearch() + { + DBSearch search = gameObject.AddComponent(); + search.llm = llm; + search.stream = false; + return search; + } 
+ } +} diff --git a/Tests/Runtime/TestSearch.cs.meta b/Tests/Runtime/TestSearch.cs.meta new file mode 100644 index 00000000..e59b5748 --- /dev/null +++ b/Tests/Runtime/TestSearch.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a180f8b149e7fa4b2bf3ee216dd1a261 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Third Party Notices.md b/Third Party Notices.md index 3daa05a2..5606bf92 100644 --- a/Third Party Notices.md +++ b/Third Party Notices.md @@ -106,6 +106,72 @@ Origin: [link](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF)
License Type: "Apache 2.0"
License: [link](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct-GGUF/blob/main/LICENSE) +
+ +### BAAI/bge-large-en-v1.5 + +Developer: BAAI
+Origin: [link](https://huggingface.co/BAAI/bge-large-en-v1.5)
+License Type: "MIT"
+License: [link](https://huggingface.co/BAAI/bge-large-en-v1.5) + +##### modified by: CompendiumLabs/bge-large-en-v1.5-gguf + +Developer: Compendium Labs
+Origin: [link](https://huggingface.co/CompendiumLabs/bge-large-en-v1.5-gguf)
+License Type: "MIT"
+License: [link](https://huggingface.co/CompendiumLabs/bge-large-en-v1.5-gguf) + +
+ +### BAAI/bge-base-en-v1.5 + +Developer: BAAI
+Origin: [link](https://huggingface.co/BAAI/bge-base-en-v1.5)
+License Type: "MIT"
+License: [link](https://huggingface.co/BAAI/bge-base-en-v1.5) + +##### modified by: CompendiumLabs/bge-base-en-v1.5-gguf + +Developer: Compendium Labs
+Origin: [link](https://huggingface.co/CompendiumLabs/bge-base-en-v1.5-gguf)
+License Type: "MIT"
+License: [link](https://huggingface.co/CompendiumLabs/bge-base-en-v1.5-gguf) + +
+ +### BAAI/bge-small-en-v1.5 + +Developer: BAAI
+Origin: [link](https://huggingface.co/BAAI/bge-small-en-v1.5)
+License Type: "MIT"
+License: [link](https://huggingface.co/BAAI/bge-small-en-v1.5) + +##### modified by: CompendiumLabs/bge-small-en-v1.5-gguf + +Developer: Compendium Labs
+Origin: [link](https://huggingface.co/CompendiumLabs/bge-small-en-v1.5-gguf)
+License Type: "MIT"
+License: [link](https://huggingface.co/CompendiumLabs/bge-small-en-v1.5-gguf) + +
+ +### sentence-transformers/all-MiniLM-L12-v2 + +Developer: Sentence Transformers
+Origin: [link](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2)
+License Type: "Apache 2.0"
+License: [link](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) + +##### modified by: leliuga/all-MiniLM-L12-v2-GGUF + +Developer: Leliuga
+Origin: [link](https://huggingface.co/leliuga/all-MiniLM-L12-v2-GGUF)
+License Type: "Apache 2.0"
+License: [link](https://huggingface.co/leliuga/all-MiniLM-L12-v2-GGUF) + +
+ --- ## Testing diff --git a/ThirdParty.meta b/ThirdParty.meta new file mode 100644 index 00000000..61b2f4b6 --- /dev/null +++ b/ThirdParty.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: cdaa3e55ae35be443aa985e442e3f4ff +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/ThirdParty/usearch.meta b/ThirdParty/usearch.meta new file mode 100644 index 00000000..a2a8b83c --- /dev/null +++ b/ThirdParty/usearch.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 47164dbfa5f37148a85f10f600e52819 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/ThirdParty/usearch/Cloud.Unum.USearch.asmdef b/ThirdParty/usearch/Cloud.Unum.USearch.asmdef new file mode 100644 index 00000000..adf4461e --- /dev/null +++ b/ThirdParty/usearch/Cloud.Unum.USearch.asmdef @@ -0,0 +1,3 @@ +{ + "name": "Cloud.Unum.USearch" +} diff --git a/Samples~/KnowledgeBaseGame/KnowledgeBase.asmdef.meta b/ThirdParty/usearch/Cloud.Unum.USearch.asmdef.meta similarity index 76% rename from Samples~/KnowledgeBaseGame/KnowledgeBase.asmdef.meta rename to ThirdParty/usearch/Cloud.Unum.USearch.asmdef.meta index 35a5d9b7..eebc7863 100644 --- a/Samples~/KnowledgeBaseGame/KnowledgeBase.asmdef.meta +++ b/ThirdParty/usearch/Cloud.Unum.USearch.asmdef.meta @@ -1,5 +1,5 @@ fileFormatVersion: 2 -guid: 42f2ee663135398278c988534b3ae0b3 +guid: 5cd85a883a79084e79040e24753a27a4 AssemblyDefinitionImporter: externalObjects: {} userData: diff --git a/ThirdParty/usearch/LICENSE b/ThirdParty/usearch/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/ThirdParty/usearch/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/ThirdParty/usearch/LICENSE.meta b/ThirdParty/usearch/LICENSE.meta new file mode 100644 index 00000000..eea86eee --- /dev/null +++ b/ThirdParty/usearch/LICENSE.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: 974cd1ee69b60ce01a6746e41ac03d19 +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/ThirdParty/usearch/NativeMethods.cs b/ThirdParty/usearch/NativeMethods.cs new file mode 100644 index 00000000..f94bbe45 --- /dev/null +++ b/ThirdParty/usearch/NativeMethods.cs @@ -0,0 +1,158 @@ +using System.Runtime.InteropServices; + +using usearch_index_t = System.IntPtr; +using usearch_key_t = System.UInt64; +using usearch_error_t = System.IntPtr; +using size_t = System.UIntPtr; +using void_ptr_t = System.IntPtr; +using usearch_distance_t = System.Single; + +namespace Cloud.Unum.USearch +{ + public static class NativeMethodsHelpers + { + [UnmanagedFunctionPointer(CallingConvention.Cdecl)] + public delegate int FilterCallback(int key, void_ptr_t filterState); + } + + internal static class NativeMethods + { + private const string LibraryName = "libusearch_c"; + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern usearch_index_t usearch_init(ref IndexOptions options, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_free(usearch_index_t index, out usearch_error_t error); + + [DllImport(LibraryName, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_save(usearch_index_t index, [MarshalAs(UnmanagedType.LPStr)] string path, out usearch_error_t error); + + [DllImport(LibraryName, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_load(usearch_index_t index, [MarshalAs(UnmanagedType.LPStr)] string path, out usearch_error_t error); + + [DllImport(LibraryName, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_view(usearch_index_t index, [MarshalAs(UnmanagedType.LPStr)] string path, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_size(usearch_index_t index, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_capacity(usearch_index_t index, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_dimensions(usearch_index_t index, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_connectivity(usearch_index_t index, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_reserve(usearch_index_t index, size_t capacity, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_add( + usearch_index_t index, + usearch_key_t key, + [In] float[] vector, + ScalarKind vector_kind, + out usearch_error_t error + ); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_add( + usearch_index_t index, + usearch_key_t key, + [In] double[] vector, + ScalarKind vector_kind, + out usearch_error_t error + ); + + 
[DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + [return : MarshalAs(UnmanagedType.I1)] + public static extern bool usearch_contains(usearch_index_t index, usearch_key_t key, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_count(usearch_index_t index, usearch_key_t key, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_search( + usearch_index_t index, + void_ptr_t query_vector, + ScalarKind query_kind, + size_t count, + [Out] usearch_key_t[] found_keys, + [Out] usearch_distance_t[] found_distances, + out usearch_error_t error + ); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_search( + usearch_index_t index, + [In] float[] query_vector, + ScalarKind query_kind, + size_t count, + [Out] usearch_key_t[] found_keys, + [Out] usearch_distance_t[] found_distances, + out usearch_error_t error + ); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_search( + usearch_index_t index, + [In] double[] query_vector, + ScalarKind query_kind, + size_t count, + [Out] usearch_key_t[] found_keys, + [Out] usearch_distance_t[] found_distances, + out usearch_error_t error + ); + + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_get( + usearch_index_t index, + usearch_key_t key, + size_t count, + [Out] float[] vector, + ScalarKind vector_kind, + out usearch_error_t error + ); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_get( + usearch_index_t index, + usearch_key_t key, + size_t count, + [Out] double[] vector, + ScalarKind vector_kind, + out usearch_error_t error + ); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_remove(usearch_index_t index, usearch_key_t key, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_rename(usearch_index_t index, usearch_key_t key_from, usearch_key_t key_to, out usearch_error_t error); + + //========================== Additional methods from LLMUnity ==========================// + + [DllImport(LibraryName, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_load_buffer(usearch_index_t index, void_ptr_t buffer, size_t length, out usearch_error_t error); + + [DllImport(LibraryName, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)] + public static extern void usearch_view_buffer(usearch_index_t index, void_ptr_t buffer, size_t length, out usearch_error_t error); + + [DllImport(LibraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern size_t usearch_filtered_search( + usearch_index_t index, + void_ptr_t query_vector, + ScalarKind query_kind, + size_t count, + NativeMethodsHelpers.FilterCallback filter, + void_ptr_t filterState, + [Out] usearch_key_t[] found_keys, + [Out] usearch_distance_t[] found_distances, + out usearch_error_t error + ); + } +} diff --git a/ThirdParty/usearch/NativeMethods.cs.meta b/ThirdParty/usearch/NativeMethods.cs.meta new file mode 100644 index 00000000..ba483ac8 --- /dev/null +++ b/ThirdParty/usearch/NativeMethods.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 
+guid: eec2de19e997c7c9e81c4c1f3dd6b78f +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/ThirdParty/usearch/README.md b/ThirdParty/usearch/README.md new file mode 100644 index 00000000..ea420ed2 --- /dev/null +++ b/ThirdParty/usearch/README.md @@ -0,0 +1,468 @@ +# USearch +## Description +USearch is a Similarity Search Engine. +The USearch repository can be found here: https://github.com/unum-cloud/usearch/ + +## License +USearch is distributed under Apache-2.0 license, see the [LICENSE file](LICENSE). + +## Local Modifications +The files from the C# code here: https://github.com/unum-cloud/usearch/tree/main/csharp/src/Cloud.Unum.USearch +have been copied inside this folder. + +Additionally, the runtime dlls have been exported from the nuget package downloaded from here: +https://www.nuget.org/packages/Cloud.Unum.USearch +and added in the x86_64 folder. + +The following modifications have been applied +- NativeMethods.cs: the usearch_view_buffer and usearch_load_buffer are being imported +- USearchIndex.cs: the save and load functions have been modified to save/load the IndexOptions along with the index in a ZipArchive +- .meta files have been automatically added by Unity +- a Cloud.Unum.USearch.asmdef file has been added to denote the name of the package +- the dll file inside the arm64 has been compiled for arm64 macOS from the USearch code repository. + +## Original Readme +The original readme is appended here: + +----- + +

USearch

+

+Smaller & Faster Single-File
+Similarity Search Engine for Vectors & 🔜 Texts
+
+Discord • LinkedIn • Twitter • Blog • GitHub
+
+Spatial • Binary • Probabilistic • User-Defined Metrics
+C++ 11 • Python 3 • JavaScript • Java • Rust • C 99 • Objective-C • Swift • C# • GoLang • Wolfram
+Linux • MacOS • Windows • iOS • WebAssembly

+ +
+PyPI • NPM • Crate • NuGet • Maven • Docker • GitHub code size in bytes
+ +--- + +- βœ… [10x faster][faster-than-faiss] [HNSW][hnsw-algorithm] implementation than [FAISS][faiss]. +- βœ… Simple and extensible [single C++11 header][usearch-header] implementation. +- βœ… Compatible with a dozen programming languages out of the box. +- βœ… [Trusted](#integrations) by some of the most loved Datalakes and Databases, like [ClickHouse][clickhouse-docs]. +- βœ… [SIMD][simd]-optimized and [user-defined metrics](#user-defined-functions) with JIT compilation. +- βœ… Hardware-agnostic `f16` & `i8` - [half-precision & quarter-precision support](#memory-efficiency-downcasting-and-quantization). +- βœ… [View large indexes from disk](#serving-index-from-disk) without loading into RAM. +- βœ… Heterogeneous lookups, renaming/relabeling, and on-the-fly deletions. +- βœ… Variable dimensionality vectors for unique applications, including search over compressed data. +- βœ… Binary Tanimoto and Sorensen coefficients for [Genomics and Chemistry applications](#usearch--rdkit--molecular-search). +- βœ… Space-efficient point-clouds with `uint40_t`, accommodating 4B+ size. +- βœ… Compatible with OpenMP and custom "executors" for fine-grained control over CPU utilization. +- βœ… Near-real-time [clustering and sub-clustering](#clustering) for Tens or Millions of clusters. +- βœ… [Semantic Search](#usearch--ai--multi-modal-semantic-search) and [Joins](#joins). + +[faiss]: https://github.com/facebookresearch/faiss +[usearch-header]: https://github.com/unum-cloud/usearch/blob/main/include/usearch/index.hpp +[obscure-use-cases]: https://ashvardanian.com/posts/abusing-vector-search +[hnsw-algorithm]: https://arxiv.org/abs/1603.09320 +[simd]: https://en.wikipedia.org/wiki/Single_instruction,_multiple_data +[faster-than-faiss]: https://www.unum.cloud/blog/2023-11-07-scaling-vector-search-with-intel +[clickhouse-docs]: https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/annindexes#usearch + +__Technical Insights__ and related articles: + +- [Uses Horner's method for polynomial approximations, beating GCC 12 by 119x](https://ashvardanian.com/posts/gcc-12-vs-avx512fp16/). +- [Uses Arm SVE and x86 AVX-512's masked loads to eliminate tail `for`-loops](https://ashvardanian.com/posts/simsimd-faster-scipy/#tails-of-the-past-the-significance-of-masked-loads). +- [Uses AVX-512 FP16 for half-precision operations, that few compilers vectorize](https://ashvardanian.com/posts/simsimd-faster-scipy/#the-challenge-of-f16). +- [Substitutes LibC's `sqrt` calls with bithacks using Jan Kadlec's constant](https://ashvardanian.com/posts/simsimd-faster-scipy/#bonus-section-bypassing-sqrt-and-libc-dependencies). +- [For every language implements a custom separate binding](https://ashvardanian.com/posts/porting-cpp-library-to-ten-languages/). +- [For Python avoids slow PyBind11, and even `PyArg_ParseTuple` for speed](https://ashvardanian.com/posts/pybind11-cpython-tutorial/). +- [For JavaScript uses typed arrays and NAPI for zero-copy calls](https://ashvardanian.com/posts/javascript-ai-vector-search/). + +## Comparison with FAISS + +FAISS is a widely recognized standard for high-performance vector search engines. +USearch and FAISS both employ the same HNSW algorithm, but they differ significantly in their design principles. +USearch is compact and broadly compatible without sacrificing performance, primarily focusing on user-defined metrics and fewer dependencies. 
+ +| | FAISS | USearch | Improvement | +| :------------------------------------------- | ----------------------: | -----------------------: | ----------------------: | +| Indexing time ⁰ | | | | +| 100 Million 96d `f32`, `f16`, `i8` vectors | 2.6 Β· 2.6 Β· 2.6 h | 0.3 Β· 0.2 Β· 0.2 h | __9.6 Β· 10.4 Β· 10.7 x__ | +| 100 Million 1536d `f32`, `f16`, `i8` vectors | 5.0 Β· 4.1 Β· 3.8 h | 2.1 Β· 1.1 Β· 0.8 h | __2.3 Β· 3.6 Β· 4.4 x__ | +| | | | | +| Codebase length ΒΉ | 84 K [SLOC][sloc] | 3 K [SLOC][sloc] | maintainable | +| Supported metrics Β² | 9 fixed metrics | any metric | extendible | +| Supported languages Β³ | C++, Python | 10 languages | portable | +| Supported ID types ⁴ | 32-bit, 64-bit | 32-bit, 40-bit, 64-bit | efficient | +| Required dependencies ⁡ | BLAS, OpenMP | - | light-weight | +| Bindings ⁢ | SWIG | Native | low-latency | +| Python binding size ⁷ | [~ 10 MB][faiss-weight] | [< 1 MB][usearch-weight] | deployable | + +[sloc]: https://en.wikipedia.org/wiki/Source_lines_of_code +[faiss-weight]: https://pypi.org/project/faiss-cpu/#files +[usearch-weight]: https://pypi.org/project/usearch/#files + +> ⁰ [Tested][intel-benchmarks] on Intel Sapphire Rapids, with the simplest inner-product distance, equivalent recall, and memory consumption while also providing far superior search speed. +> ΒΉ A shorter codebase of `usearch/` over `faiss/` makes the project easier to maintain and audit. +> Β² User-defined metrics allow you to customize your search for various applications, from GIS to creating custom metrics for composite embeddings from multiple AI models or hybrid full-text and semantic search. +> Β³ With USearch, you can reuse the same preconstructed index in various programming languages. +> ⁴ The 40-bit integer allows you to store 4B+ vectors without allocating 8 bytes for every neighbor reference in the proximity graph. +> ⁡ Lack of obligatory dependencies makes USearch much more portable. +> ⁢ Native bindings introduce lower call latencies than more straightforward approaches. +> ⁷ Lighter bindings make downloads and deployments faster. + +[intel-benchmarks]: https://www.unum.cloud/blog/2023-11-07-scaling-vector-search-with-intel + +Base functionality is identical to FAISS, and the interface must be familiar if you have ever investigated Approximate Nearest Neighbors search: + +```py +$ pip install numpy usearch + +import numpy as np +from usearch.index import Index + +index = Index(ndim=3) + +vector = np.array([0.2, 0.6, 0.4]) +index.add(42, vector) + +matches = index.search(vector, 10) + +assert matches[0].key == 42 +assert matches[0].distance <= 0.001 +assert np.allclose(index[42], vector) +``` + +More settings are always available, and the API is designed to be as flexible as possible. + +```py +index = Index( + ndim=3, # Define the number of dimensions in input vectors + metric='cos', # Choose 'l2sq', 'haversine' or other metric, default = 'ip' + dtype='f32', # Quantize to 'f16' or 'i8' if needed, default = 'f32' + connectivity=16, # Optional: Limit number of neighbors per graph node + expansion_add=128, # Optional: Control the recall of indexing + expansion_search=64, # Optional: Control the quality of the search + multi=False, # Optional: Allow multiple vectors per key, default = False +) +``` + +## Serialization & Serving `Index` from Disk + +USearch supports multiple forms of serialization: + +- Into a __file__ defined with a path. +- Into a __stream__ defined with a callback, serializing or reconstructing incrementally. 
+- Into a __buffer__ of fixed length or a memory-mapped file that supports random access.
+
+The latter allows you to serve indexes from external memory, enabling you to optimize your server choices for indexing speed and serving costs.
+This can result in a __20x cost reduction__ on AWS and other public clouds.
+
+```py
+index.save("index.usearch")
+
+loaded_copy = index.load("index.usearch")
+view = Index.restore("index.usearch", view=True)
+
+other_view = Index(ndim=..., metric=...)
+other_view.view("index.usearch")
+```
+
+## Exact vs. Approximate Search
+
+Approximate search methods, such as HNSW, are predominantly used when an exact brute-force search becomes too resource-intensive.
+This typically occurs when you have millions of entries in a collection.
+For smaller collections, we offer a more direct approach with the `search` method.
+
+```py
+from usearch.index import search, MetricKind, Matches, BatchMatches
+import numpy as np
+
+# Generate 10'000 random vectors with 1024 dimensions
+vectors = np.random.rand(10_000, 1024).astype(np.float32)
+vector = np.random.rand(1024).astype(np.float32)
+
+one_in_many: Matches = search(vectors, vector, 50, MetricKind.L2sq, exact=True)
+many_in_many: BatchMatches = search(vectors, vectors, 50, MetricKind.L2sq, exact=True)
+```
+
+If you pass the `exact=True` argument, the system bypasses indexing altogether and performs a brute-force search through the entire dataset using SIMD-optimized similarity metrics from [SimSIMD](https://github.com/ashvardanian/simsimd).
+When compared to FAISS's `IndexFlatL2` in Google Colab, __[USearch may offer up to a 20x performance improvement](https://github.com/unum-cloud/usearch/issues/176#issuecomment-1666650778)__:
+
+- `faiss.IndexFlatL2`: __55.3 ms__.
+- `usearch.index.search`: __2.54 ms__.
+
+## `Indexes` for Multi-Index Lookups
+
+For larger workloads targeting billions or even trillions of vectors, parallel multi-index lookups become invaluable.
+Instead of constructing one extensive index, you can build multiple smaller ones and view them together.
+
+```py
+from usearch.index import Indexes
+
+multi_index = Indexes(
+    indexes=[...],  # Iterable[usearch.index.Index]
+    paths=[...],    # Iterable[os.PathLike]
+    view=False,
+    threads=0,
+)
+multi_index.search(...)
+```
+
+## Clustering
+
+Once the index is constructed, USearch can perform K-Nearest Neighbors Clustering much faster than standalone clustering libraries like SciPy, UMap, and tSNE.
+The same applies to dimensionality reduction with PCA.
+Essentially, the `Index` itself can be treated as a clustering that allows iterative deepening.
+
+```py
+clustering = index.cluster(
+    min_count=10, # Optional
+    max_count=15, # Optional
+    threads=...,  # Optional
+)
+
+# Get the clusters and their sizes
+centroid_keys, sizes = clustering.centroids_popularity
+
+# Use Matplotlib to draw a histogram
+clustering.plot_centroids_popularity()
+
+# Export a NetworkX graph of the clusters
+g = clustering.network
+
+# Get members of a specific cluster
+first_members = clustering.members_of(centroid_keys[0])
+
+# Deepen into that cluster, splitting it into more parts; all the same arguments are supported
+sub_clustering = clustering.subcluster(min_count=..., max_count=...)
+```
+
+The resulting clustering isn't identical to K-Means or other conventional approaches but serves the same purpose.
+Alternatively, using Scikit-Learn on a 1 Million point dataset, one may expect queries to take anywhere from minutes to hours, depending on the number of clusters you want to highlight.
+For 50'000 clusters, the performance difference between USearch and conventional clustering methods may easily reach 100x.
+
+## Joins, One-to-One, One-to-Many, and Many-to-Many Mappings
+
+One of the big questions these days is how AI will change the world of databases and data management.
+Most databases are still struggling to implement high-quality fuzzy search, and the only kind of join they know is deterministic.
+A `join` differs from searching for every entry: it requires a one-to-one mapping that bans collisions among separate search results.
+
+| Exact Search | Fuzzy Search | Semantic Search ? |
+| :----------: | :----------: | :---------------: |
+| Exact Join | Fuzzy Join ? | Semantic Join ?? |
+
+Using USearch, one can implement sub-quadratic complexity approximate, fuzzy, and semantic joins.
+This can be useful for the fuzzy-matching tasks common in Database Management Software.
+
+```py
+men = Index(...)
+women = Index(...)
+pairs: dict = men.join(women, max_proposals=0, exact=False)
+```
+
+> Read more in the post: [Combinatorial Stable Marriages for Semantic Search 💍](https://ashvardanian.com/posts/searching-stable-marriages)
+
+## User-Defined Functions
+
+While most vector search packages concentrate on just a few metrics, like the "Inner Product distance" and the "Euclidean distance", USearch extends this list to include any user-defined metric.
+This flexibility allows you to customize your search for various applications, from computing geospatial coordinates with the rare [Haversine][haversine] distance to creating custom metrics for composite embeddings from multiple AI models.
+
+![USearch: Vector Search Approaches](https://github.com/unum-cloud/usearch/blob/main/assets/usearch-approaches-white.png?raw=true)
+
+Unlike older approaches to indexing high-dimensional spaces, like KD-Trees and Locality-Sensitive Hashing, HNSW doesn't require vectors to be identical in length.
+They only have to be comparable.
+So you can apply it in [obscure][obscure] applications, like searching for similar sets or fuzzy text matching, using [GZip][gzip-similarity] as a distance function.
+
+> Read more about [JIT and UDF in the USearch Python SDK](https://unum-cloud.github.io/usearch/python#user-defined-metrics-and-jit-in-python).
+
+[haversine]: https://ashvardanian.com/posts/abusing-vector-search#geo-spatial-indexing
+[obscure]: https://ashvardanian.com/posts/abusing-vector-search
+[gzip-similarity]: https://twitter.com/LukeGessler/status/1679211291292889100?s=20
+
+## Memory Efficiency, Downcasting, and Quantization
+
+Training a quantization model and reducing dimensionality are common approaches to accelerating vector search.
+Those, however, are not always reliable, can significantly affect the statistical properties of your data, and require regular adjustments if your distribution shifts.
+Instead, we have focused on high-precision arithmetic over low-precision downcasted vectors.
+The same index, `add`, and `search` operations will automatically down-cast or up-cast between `f64_t`, `f32_t`, `f16_t`, `i8_t`, and single-bit representations.
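+For instance, here is a minimal sketch reusing the same Python API as the earlier examples; how much recall the `i8` quantization costs depends on your data:
+
+```py
+import numpy as np
+from usearch.index import Index
+
+# Store vectors as 8-bit integers internally, while callers keep passing float32 arrays
+index = Index(ndim=256, metric='cos', dtype='i8')
+
+vectors = np.random.rand(1_000, 256).astype(np.float32)
+for key, vector in enumerate(vectors):
+    index.add(key, vector)               # down-cast happens on insertion
+
+matches = index.search(vectors[42], 10)  # queries can stay float32 as well
+```
+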
+You can use the following command to check if hardware acceleration is enabled:
+
+```sh
+$ python -c 'from usearch.index import Index; print(Index(ndim=768, metric="cos", dtype="f16").hardware_acceleration)'
+> avx512+f16
+$ python -c 'from usearch.index import Index; print(Index(ndim=166, metric="tanimoto").hardware_acceleration)'
+> avx512+popcnt
+```
+
+Using smaller numeric types saves the RAM needed to store the vectors, but you can also compress the neighbor lists that form the proximity graph.
+By default, a 32-bit `uint32_t` is used to enumerate those, which is not enough if you need to address over 4 Billion entries.
+For such cases, we provide a custom `uint40_t` type that is still 37.5% more space-efficient than the commonly used 8-byte integers and scales up to 1 Trillion entries.
+
+![USearch uint40_t support](https://github.com/unum-cloud/usearch/blob/main/assets/usearch-neighbor-types.png?raw=true)
+
+## Functionality
+
+By now, the core functionality is supported across all bindings.
+Broader functionality is ported per request.
+In some cases, like batch operations, feature parity is unnecessary: the host language has full multi-threading capabilities, and the USearch index structure is concurrent by design, so users can implement batching, scheduling, and load-balancing in whatever way best suits their applications.
+
+| | C++ 11 | Python 3 | C 99 | Java | JavaScript | Rust | GoLang | Swift |
+| :---------------------- | :----: | :------: | :---: | :---: | :--------: | :---: | :----: | :---: |
+| Add, search, remove | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Save, load, view | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| User-defined metrics | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| Batch operations | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| Joins | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| Variable-length vectors | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| 4B+ capacities | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+
+## Application Examples
+
+### USearch + AI = Multi-Modal Semantic Search
+
+[![USearch Semantic Image Search](https://github.com/ashvardanian/usearch-images/raw/main/assets/usearch-images-slow.gif)](https://github.com/ashvardanian/usearch-images)
+
+AI has a growing number of applications, but one of the coolest classic ideas is to use it for Semantic Search.
+One can take an encoder model, like the multi-modal [UForm](https://github.com/unum-cloud/uform), and a web-programming framework, like UCall, and build a text-to-image search platform in just 20 lines of Python.
+
+```python
+import ucall
+import uform
+import usearch
+
+import numpy as np
+import PIL as pil
+
+server = ucall.Server()
+model = uform.get_model('unum-cloud/uform-vl-multilingual')
+index = usearch.index.Index(ndim=256)
+
+@server
+def add(key: int, photo: pil.Image.Image):
+    image = model.preprocess_image(photo)
+    vector = model.encode_image(image).detach().numpy()
+    index.add(key, vector.flatten(), copy=True)
+
+@server
+def search(query: str) -> np.ndarray:
+    tokens = model.preprocess_text(query)
+    vector = model.encode_text(tokens).detach().numpy()
+    matches = index.search(vector.flatten(), 3)
+    return matches.keys
+
+server.run()
+```
+
+A more complete [demo with Streamlit is available on GitHub](https://github.com/ashvardanian/usearch-images).
+We have pre-processed some commonly used datasets, cleaned the images, produced the vectors, and pre-built the index.
+
+| Dataset | Modalities | Images | Download |
+| :---------------------------------- | --------------------: | -----: | ------------------------------------: |
+| [Unsplash][unsplash-25k-origin] | Images & Descriptions | 25 K | [HuggingFace / Unum][unsplash-25k-hf] |
+| [Conceptual Captions][cc-3m-origin] | Images & Descriptions | 3 M | [HuggingFace / Unum][cc-3m-hf] |
+| [Arxiv][arxiv-2m-origin] | Titles & Abstracts | 2 M | [HuggingFace / Unum][arxiv-2m-hf] |
+
+[unsplash-25k-origin]: https://github.com/unsplash/datasets
+[cc-3m-origin]: https://huggingface.co/datasets/conceptual_captions
+[arxiv-2m-origin]: https://www.kaggle.com/datasets/Cornell-University/arxiv
+
+[unsplash-25k-hf]: https://huggingface.co/datasets/unum-cloud/ann-unsplash-25k
+[cc-3m-hf]: https://huggingface.co/datasets/unum-cloud/ann-cc-3m
+[arxiv-2m-hf]: https://huggingface.co/datasets/unum-cloud/ann-arxiv-2m
+
+### USearch + RDKit = Molecular Search
+
+Comparing molecule graphs and searching for similar structures is expensive and slow.
+It can be seen as a special case of the NP-Complete Subgraph Isomorphism problem.
+Luckily, domain-specific approximate methods exist.
+The one commonly used in Chemistry is to generate structures from [SMILES][smiles] strings and later hash them into binary fingerprints.
+The latter are searchable with binary similarity metrics, like the Tanimoto coefficient.
+Below is an example using the RDKit package.
+
+```python
+from usearch.index import Index, MetricKind
+from rdkit import Chem
+from rdkit.Chem import AllChem
+
+import numpy as np
+
+molecules = [Chem.MolFromSmiles('CCOC'), Chem.MolFromSmiles('CCO')]
+encoder = AllChem.GetRDKitFPGenerator()
+
+fingerprints = np.vstack([encoder.GetFingerprint(x) for x in molecules])
+fingerprints = np.packbits(fingerprints, axis=1)
+
+index = Index(ndim=2048, metric=MetricKind.Tanimoto)
+keys = np.arange(len(molecules))
+
+index.add(keys, fingerprints)
+matches = index.search(fingerprints, 10)
+```
+
+That method was used to build the ["USearch Molecules"](https://github.com/ashvardanian/usearch-molecules), one of the largest Chem-Informatics datasets, containing 7 billion small molecules and 28 billion fingerprints.
+
+[smiles]: https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system
+[rdkit-fingerprints]: https://www.rdkit.org/docs/RDKit_Book.html#additional-information-about-the-fingerprints
+
+### USearch + POI Coordinates = GIS Applications... on iOS?
+
+[![USearch Maps with SwiftUI](https://github.com/ashvardanian/SwiftVectorSearch/raw/main/USearch+SwiftUI.gif)](https://github.com/ashvardanian/SwiftVectorSearch)
+
+With Objective-C and Swift iOS bindings, USearch can be easily used in mobile applications.
+The [SwiftVectorSearch](https://github.com/ashvardanian/SwiftVectorSearch) project illustrates how to build a dynamic, real-time search system on iOS.
+In this example, we use 2-dimensional vectors, encoding latitude and longitude, to find the closest Points of Interest (POIs) on a map.
+The search is based on the Haversine distance metric but can easily be extended to support high-dimensional vectors.
+
+## Integrations
+
+- [x] GPTCache: [Python](https://github.com/zilliztech/GPTCache/releases/tag/0.1.29).
+- [x] LangChain: [Python](https://github.com/langchain-ai/langchain/releases/tag/v0.0.257) and [JavaScript](https://github.com/hwchase17/langchainjs/releases/tag/0.0.125).
+- [x] ClickHouse: [C++](https://github.com/ClickHouse/ClickHouse/pull/53447).
+- [x] Microsoft Semantic Kernel: [Python](https://github.com/microsoft/semantic-kernel/releases/tag/python-0.3.9.dev) and C#. +- [x] LanternDB: [C++](https://github.com/lanterndata/lantern) and [Rust](https://github.com/lanterndata/lantern_extras). + +## Citations + +```txt +@software{Vardanian_USearch_2023, +doi = {10.5281/zenodo.7949416}, +author = {Vardanian, Ash}, +title = {{USearch by Unum Cloud}}, +url = {https://github.com/unum-cloud/usearch}, +version = {2.8.16}, +year = {2023}, +month = oct, +} +``` diff --git a/ThirdParty/usearch/README.md.meta b/ThirdParty/usearch/README.md.meta new file mode 100644 index 00000000..fa1c62fe --- /dev/null +++ b/ThirdParty/usearch/README.md.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: e64a900087d96cf7fbdefd8a24819eb8 +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/ThirdParty/usearch/USearchException.cs b/ThirdParty/usearch/USearchException.cs new file mode 100644 index 00000000..5f5f258d --- /dev/null +++ b/ThirdParty/usearch/USearchException.cs @@ -0,0 +1,9 @@ +using System; + +namespace Cloud.Unum.USearch +{ + public class USearchException : Exception + { + public USearchException(string message) : base(message) {} + } +} diff --git a/ThirdParty/usearch/USearchException.cs.meta b/ThirdParty/usearch/USearchException.cs.meta new file mode 100644 index 00000000..f3ff8a54 --- /dev/null +++ b/ThirdParty/usearch/USearchException.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 587fe86c17f98ed86af2b0adda713e91 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/ThirdParty/usearch/USearchIndex.cs b/ThirdParty/usearch/USearchIndex.cs new file mode 100644 index 00000000..3f87c4ed --- /dev/null +++ b/ThirdParty/usearch/USearchIndex.cs @@ -0,0 +1,398 @@ +using System; +using System.IO; +using System.IO.Compression; +using System.Runtime.InteropServices; +using UnityEngine; +using static Cloud.Unum.USearch.NativeMethods; + +namespace Cloud.Unum.USearch +{ + /// + /// USearchIndex class provides a managed wrapper for the USearch library's index functionality. + /// + public class USearchIndex : IDisposable + { + private IntPtr _index; + private bool _disposedValue = false; + private ulong _cachedDimensions; + + public USearchIndex( + MetricKind metricKind, + ScalarKind quantization, + ulong dimensions, + ulong connectivity = 0, + ulong expansionAdd = 0, + ulong expansionSearch = 0, + bool multi = false + //CustomDistanceFunction? 
customMetric = null + ) + { + IndexOptions initOptions = new() + { + metric_kind = metricKind, + metric = default, + quantization = quantization, + dimensions = dimensions, + connectivity = connectivity, + expansion_add = expansionAdd, + expansion_search = expansionSearch, + multi = multi + }; + + this._index = usearch_init(ref initOptions, out IntPtr error); + HandleError(error); + this._cachedDimensions = dimensions; + } + + public USearchIndex(IndexOptions options) + { + this._index = usearch_init(ref options, out IntPtr error); + HandleError(error); + this._cachedDimensions = options.dimensions; + } + + public USearchIndex(string path, bool view = false) + { + IndexOptions initOptions = new(); + this._index = usearch_init(ref initOptions, out IntPtr error); + HandleError(error); + + if (view) + { + usearch_view(this._index, path, out error); + } + else + { + usearch_load(this._index, path, out error); + } + + HandleError(error); + + this._cachedDimensions = this.Dimensions(); + } + + public void Save(string path) + { + usearch_save(this._index, path, out IntPtr error); + HandleError(error); + } + + public ulong Size() + { + ulong size = (ulong)usearch_size(this._index, out IntPtr error); + HandleError(error); + return size; + } + + public ulong Capacity() + { + ulong capacity = (ulong)usearch_capacity(this._index, out IntPtr error); + HandleError(error); + return capacity; + } + + public ulong Dimensions() + { + ulong dimensions = (ulong)usearch_dimensions(this._index, out IntPtr error); + HandleError(error); + return dimensions; + } + + public ulong Connectivity() + { + ulong connectivity = (ulong)usearch_connectivity(this._index, out IntPtr error); + HandleError(error); + return connectivity; + } + + public bool Contains(ulong key) + { + bool result = usearch_contains(this._index, key, out IntPtr error); + HandleError(error); + return result; + } + + public int Count(ulong key) + { + int count = checked((int)usearch_count(this._index, key, out IntPtr error)); + HandleError(error); + return count; + } + + private void IncreaseCapacity(ulong size) + { + usearch_reserve(this._index, (UIntPtr)(this.Size() + size), out IntPtr error); + HandleError(error); + } + + private void CheckIncreaseCapacity(ulong size_increase) + { + ulong size_demand = this.Size() + size_increase; + if (this.Capacity() < size_demand) + { + this.IncreaseCapacity(size_increase); + } + } + + public void Add(ulong key, float[] vector) + { + this.CheckIncreaseCapacity(1); + usearch_add(this._index, key, vector, ScalarKind.Float32, out IntPtr error); + HandleError(error); + } + + public void Add(ulong key, double[] vector) + { + this.CheckIncreaseCapacity(1); + usearch_add(this._index, key, vector, ScalarKind.Float64, out IntPtr error); + HandleError(error); + } + + public void Add(ulong[] keys, float[][] vectors) + { + this.CheckIncreaseCapacity((ulong)vectors.Length); + for (int i = 0; i < vectors.Length; i++) + { + usearch_add(this._index, keys[i], vectors[i], ScalarKind.Float32, out IntPtr error); + HandleError(error); + } + } + + public void Add(ulong[] keys, double[][] vectors) + { + this.CheckIncreaseCapacity((ulong)vectors.Length); + for (int i = 0; i < vectors.Length; i++) + { + usearch_add(this._index, keys[i], vectors[i], ScalarKind.Float64, out IntPtr error); + HandleError(error); + } + } + + public int Get(ulong key, out float[] vector) + { + vector = new float[this._cachedDimensions]; + int foundVectorsCount = checked((int)usearch_get(this._index, key, (UIntPtr)1, vector, ScalarKind.Float32, out IntPtr 
error)); + HandleError(error); + if (foundVectorsCount < 1) + { + vector = null; + } + + return foundVectorsCount; + } + + public int Get(ulong key, int count, out float[][] vectors) + { + var flattenVectors = new float[count * (int)this._cachedDimensions]; + int foundVectorsCount = checked((int)usearch_get(this._index, key, (UIntPtr)count, flattenVectors, ScalarKind.Float32, out IntPtr error)); + HandleError(error); + if (foundVectorsCount < 1) + { + vectors = null; + } + else + { + vectors = new float[foundVectorsCount][]; + for (int i = 0; i < foundVectorsCount; i++) + { + vectors[i] = new float[this._cachedDimensions]; + Array.Copy(flattenVectors, i * (int)this._cachedDimensions, vectors[i], 0, (int)this._cachedDimensions); + } + } + + return foundVectorsCount; + } + + public int Get(ulong key, out double[] vector) + { + vector = new double[this._cachedDimensions]; + int foundVectorsCount = checked((int)usearch_get(this._index, key, (UIntPtr)1, vector, ScalarKind.Float64, out IntPtr error)); + HandleError(error); + if (foundVectorsCount < 1) + { + vector = null; + } + + return foundVectorsCount; + } + + public int Get(ulong key, int count, out double[][] vectors) + { + var flattenVectors = new double[count * (int)this._cachedDimensions]; + int foundVectorsCount = checked((int)usearch_get(this._index, key, (UIntPtr)count, flattenVectors, ScalarKind.Float64, out IntPtr error)); + HandleError(error); + if (foundVectorsCount < 1) + { + vectors = null; + } + else + { + vectors = new double[foundVectorsCount][]; + for (int i = 0; i < foundVectorsCount; i++) + { + vectors[i] = new double[this._cachedDimensions]; + Array.Copy(flattenVectors, i * (int)this._cachedDimensions, vectors[i], 0, (int)this._cachedDimensions); + } + } + + return foundVectorsCount; + } + + public int Remove(ulong key) + { + int removedCount = checked((int)usearch_remove(this._index, key, out IntPtr error)); + HandleError(error); + return removedCount; + } + + public int Rename(ulong keyFrom, ulong keyTo) + { + int foundVectorsCount = checked((int)usearch_rename(this._index, keyFrom, keyTo, out IntPtr error)); + HandleError(error); + return foundVectorsCount; + } + + private static void HandleError(IntPtr error) + { + if (error != IntPtr.Zero) + { + throw new USearchException($"USearch operation failed: {Marshal.PtrToStringAnsi(error)}"); + } + } + + private void FreeIndex() + { + if (this._index != IntPtr.Zero) + { + usearch_free(this._index, out IntPtr error); + HandleError(error); + this._index = IntPtr.Zero; + } + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + this.FreeIndex(); + this._disposedValue = true; + } + } + + ~USearchIndex() => this.Dispose(false); + + //========================== Additional methods from LLMUnity ==========================// + + private int Search(T[] queryVector, int count, out ulong[] keys, out float[] distances, ScalarKind scalarKind, NativeMethodsHelpers.FilterCallback filter = null) + { + keys = new ulong[count]; + distances = new float[count]; + + GCHandle handle = GCHandle.Alloc(queryVector, GCHandleType.Pinned); + int matches = 0; + try + { + IntPtr queryVectorPtr = handle.AddrOfPinnedObject(); + IntPtr error; + if (filter == null) + { + matches = checked((int)usearch_search(this._index, queryVectorPtr, scalarKind, (UIntPtr)count, keys, distances, out error)); + } + else + { + matches = checked((int)usearch_filtered_search(this._index, queryVectorPtr, 
scalarKind, (UIntPtr)count, filter, IntPtr.Zero, keys, distances, out error)); + } + + // matches = checked((int)usearch_search(this._index, queryVectorPtr, scalarKind, (UIntPtr)count, keys, distances, out IntPtr error)); + HandleError(error); + } + finally + { + handle.Free(); + } + + if (matches < count) + { + Array.Resize(ref keys, (int)matches); + Array.Resize(ref distances, (int)matches); + } + + return matches; + } + + public int Search(float[] queryVector, int count, out ulong[] keys, out float[] distances, NativeMethodsHelpers.FilterCallback filter = null) + { + return this.Search(queryVector, count, out keys, out distances, ScalarKind.Float32, filter); + } + + public int Search(double[] queryVector, int count, out ulong[] keys, out float[] distances, NativeMethodsHelpers.FilterCallback filter = null) + { + return this.Search(queryVector, count, out keys, out distances, ScalarKind.Float64, filter); + } + + protected virtual string GetIndexFilename() + { + return "index"; + } + + public void Save(ZipArchive zipArchive) + { + string indexPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + Save(indexPath); + try + { + zipArchive.CreateEntryFromFile(indexPath, GetIndexFilename()); + } + catch (Exception ex) + { + Debug.LogError($"Error adding file to the zip archive: {ex.Message}"); + } + File.Delete(indexPath); + } + + public void Load(ZipArchive zipArchive) + { + IndexOptions initOptions = new(); + this._index = usearch_init(ref initOptions, out IntPtr error); + HandleError(error); + + try + { + ZipArchiveEntry entry = zipArchive.GetEntry(GetIndexFilename()); + using (Stream entryStream = entry.Open()) + using (MemoryStream memoryStream = new MemoryStream()) + { + entryStream.CopyTo(memoryStream); + // Access the length and create a buffer + byte[] managedBuffer = new byte[memoryStream.Length]; + memoryStream.Position = 0; // Reset the position to the beginning + memoryStream.Read(managedBuffer, 0, managedBuffer.Length); + + GCHandle handle = GCHandle.Alloc(managedBuffer, GCHandleType.Pinned); + try + { + IntPtr unmanagedBuffer = handle.AddrOfPinnedObject(); + usearch_load_buffer(_index, unmanagedBuffer, (UIntPtr)managedBuffer.Length, out error); + HandleError(error); + } + finally + { + handle.Free(); + } + } + } + catch (Exception ex) + { + Debug.LogError($"Error loading the search index: {ex.Message}"); + } + + this._cachedDimensions = this.Dimensions(); + } + } +} diff --git a/ThirdParty/usearch/USearchIndex.cs.meta b/ThirdParty/usearch/USearchIndex.cs.meta new file mode 100644 index 00000000..623202bf --- /dev/null +++ b/ThirdParty/usearch/USearchIndex.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: bc9f068a1a8eab16fad8f42e4b0b013e +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/ThirdParty/usearch/USearchTypes.cs b/ThirdParty/usearch/USearchTypes.cs new file mode 100644 index 00000000..c780181b --- /dev/null +++ b/ThirdParty/usearch/USearchTypes.cs @@ -0,0 +1,72 @@ +using System; +using System.Runtime.InteropServices; + +namespace Cloud.Unum.USearch +{ + public enum MetricKind : uint + { + Unknown = 0, + Cos, + Ip, + L2sq, + Haversine, + Pearson, + Jaccard, + Hamming, + Tanimoto, + Sorensen, + } + + public enum ScalarKind : uint + { + Unknown = 0, + Float32, + Float64, + Float16, + Int8, + Byte1, + } + + // TODO: implement custom metric delegate + // Microsoft guides links: + // 1) 
https://learn.microsoft.com/en-us/dotnet/standard/native-interop/best-practices + // 2) https://learn.microsoft.com/en-us/dotnet/framework/interop/marshalling-a-delegate-as-a-callback-method + // public delegate float CustomMetricFunction(IntPtr a, IntPtr b); + + [Serializable] + [StructLayout(LayoutKind.Sequential)] + public struct IndexOptions + { + public MetricKind metric_kind; + public IntPtr metric; + public ScalarKind quantization; + public ulong dimensions; + public ulong connectivity; + public ulong expansion_add; + public ulong expansion_search; + + [MarshalAs(UnmanagedType.Bool)] + public bool multi; + + public IndexOptions( + MetricKind metricKind = MetricKind.Unknown, + IntPtr metric = default, + ScalarKind quantization = ScalarKind.Unknown, + ulong dimensions = 0, + ulong connectivity = 0, + ulong expansionAdd = 0, + ulong expansionSearch = 0, + bool multi = false + ) + { + this.metric_kind = metricKind; + this.metric = default; // TODO: Use actual metric param, when support is added for custom metric delegate + this.quantization = quantization; + this.dimensions = dimensions; + this.connectivity = connectivity; + this.expansion_add = expansionAdd; + this.expansion_search = expansionSearch; + this.multi = multi; + } + } +} diff --git a/ThirdParty/usearch/USearchTypes.cs.meta b/ThirdParty/usearch/USearchTypes.cs.meta new file mode 100644 index 00000000..5e16b24c --- /dev/null +++ b/ThirdParty/usearch/USearchTypes.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: cb3a11c4f75b85afea970e466d0f3120 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: