-
Notifications
You must be signed in to change notification settings - Fork 1
/
ModelProperties.cs
139 lines (125 loc) · 5.06 KB
/
ModelProperties.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Nodes;
using OpenAI.Models;
namespace TextForge
{
/// <summary>
/// Static helpers that describe models: the built-in list of embedding model names and
/// the context window length to assume for a given model name.
/// </summary>
internal class ModelProperties
{
    // Public

    /// <summary>Context window length returned when nothing more specific is known for a model.</summary>
    public const int BaselineContextWindowLength = 4096; // Change this if necessary

    /// <summary>The built-in list of embedding model names (the backing list itself, not a copy).</summary>
    public static List<string> UniqueEmbedModels { get { return _embedModels; } }

    // Private
    private static readonly List<string> _embedModels = new List<string>()
    {
        "all-minilm",
        "bge-m3",
        "bge-large",
        "paraphrase-multilingual"
    };

    // Exact model ids; checked before any prefix rule so they can override it.
    private static readonly Dictionary<string, int> openAIModelsContextLength = new Dictionary<string, int>()
    {
        { "gpt-4-0125-preview", 128000 },
        { "gpt-4-1106-preview", 128000 },
        { "gpt-3.5-turbo-instruct", 4096 },
    };

    // Prefix rules, evaluated in order: the first matching prefix wins, so every
    // more specific "gpt-4..." prefix must stay above the generic "gpt-4" entry.
    private static readonly KeyValuePair<string, int>[] modelPrefixContextLength = new[]
    {
        new KeyValuePair<string, int>("o1", 128000),
        new KeyValuePair<string, int>("gpt-4-turbo", 128000),
        new KeyValuePair<string, int>("gpt-4-mini", 128000), // kept from the original table for backward compatibility
        new KeyValuePair<string, int>("gpt-4o", 128000),     // also covers "gpt-4o-mini"; previously fell through to 8192
        new KeyValuePair<string, int>("gpt-4", 8192),
        new KeyValuePair<string, int>("gpt-3.5-turbo", 16385),
    };

    private static bool IsOllamaEndpoint = false;
    private static bool IsOllamaFetched = false;
    private static Dictionary<string, int> ollamaContextWindowCache = new Dictionary<string, int>();
    private static readonly CultureLocalizationHelper _cultureHelper = new CultureLocalizationHelper("TextForge.Forge", typeof(Forge).Assembly);

    /// <summary>
    /// Returns the context window length to assume for <paramref name="modelName"/>.
    /// </summary>
    /// <remarks>
    /// Resolution order: exact-id table, then (for names containing ':') the Ollama lookup,
    /// then the ordered prefix table, then <see cref="BaselineContextWindowLength"/>.
    /// </remarks>
    /// <param name="modelName">Model identifier; null or empty yields the baseline.</param>
    /// <param name="availableModels">Models reported by the endpoint; only consulted the first time a name containing ':' is resolved.</param>
    /// <returns>The context window length for the model.</returns>
    public static int GetContextLength(string modelName, OpenAIModelCollection availableModels)
    {
        if (string.IsNullOrEmpty(modelName))
        {
            return BaselineContextWindowLength;
        }

        int exactLength;
        if (openAIModelsContextLength.TryGetValue(modelName, out exactLength))
        {
            return exactLength;
        }

        if (modelName.IndexOf(':') >= 0)
        {
            return GetTaggedModelContextLength(modelName, availableModels);
        }

        foreach (var rule in modelPrefixContextLength)
        {
            if (modelName.StartsWith(rule.Key, StringComparison.Ordinal))
            {
                return rule.Value;
            }
        }

        return BaselineContextWindowLength;
    }

    /// <summary>Projects the ids of <paramref name="availableModels"/> into a new list.</summary>
    public static IEnumerable<string> GetModelList(OpenAIModelCollection availableModels)
    {
        return availableModels.Select(info => info.Id).ToList();
    }

    // Resolves a model name that contains ':'. Queries Ollama (with a per-model cache) when the
    // endpoint was detected as Ollama; otherwise, or when the lookup reports a missing context
    // window (a warning is displayed in that case), falls back to the baseline.
    private static int GetTaggedModelContextLength(string modelName, OpenAIModelCollection availableModels)
    {
        try
        {
            // Endpoint detection runs once; the result is reused for all later calls.
            if (!IsOllamaFetched)
            {
                IsOllamaEndpoint = IsOllama(availableModels);
                IsOllamaFetched = true;
            }

            if (!IsOllamaEndpoint)
            {
                return BaselineContextWindowLength;
            }

            int contextWindow;
            if (!ollamaContextWindowCache.TryGetValue(modelName, out contextWindow))
            {
                contextWindow = GetOllamaModelContextWindow(modelName);
                ollamaContextWindowCache[modelName] = contextWindow;
            }
            return contextWindow;
        }
        catch (OllamaMissingContextWindowException ex)
        {
            CommonUtils.DisplayWarning(ex);
            return BaselineContextWindowLength;
        }
    }

    // Treats the endpoint as Ollama when the first listed model is owned by "library".
    // An empty model list counts as "not Ollama".
    private static bool IsOllama(OpenAIModelCollection availableModels)
    {
        return availableModels.Count != 0 && availableModels.First().OwnedBy == "library";
    }

    // Reads the context window of an Ollama model from the first "model_info" entry whose key
    // ends with ".context_length" and whose value is a valid integer.
    // Throws OllamaMissingContextWindowException when no such entry exists.
    private static int GetOllamaModelContextWindow(string model)
    {
        var ollamaEndpoint = ThisAddIn.OpenAIEndpoint.Replace("/v1", "");
        Ollama ollamaInstance = new Ollama(new Uri(ollamaEndpoint));
        // NOTE(review): this blocks on the result; if Show() is async, consider awaiting it instead.
        var dict = ollamaInstance.Show(model, true).Result;

        if (dict.TryGetValue("model_info", out var modelInfoObj) && modelInfoObj is JsonElement modelInfoElement)
        {
            // "as JsonObject" yields null for JSON null or non-object payloads instead of throwing.
            var modelInfo = JsonNode.Parse(modelInfoElement.GetRawText()) as JsonObject;
            if (modelInfo != null)
            {
                foreach (var entry in modelInfo)
                {
                    int contextLength;
                    // Invariant-culture TryParse: a null or malformed value is skipped rather than
                    // surfacing as an unhandled FormatException/NullReferenceException.
                    if (entry.Key.EndsWith(".context_length", StringComparison.Ordinal)
                        && entry.Value != null
                        && int.TryParse(
                            entry.Value.ToString(),
                            System.Globalization.NumberStyles.Integer,
                            System.Globalization.CultureInfo.InvariantCulture,
                            out contextLength))
                    {
                        return contextLength;
                    }
                }
            }
        }

        throw new OllamaMissingContextWindowException(string.Format(_cultureHelper.GetLocalizedString("(ModelProperties.cs) [GetContextLength] OllamaMissingContextWindowException #1"), model));
    }
}
/// <summary>
/// Raised when no context window length can be found in an Ollama model's metadata.
/// </summary>
public class OllamaMissingContextWindowException : ApplicationException
{
    /// <summary>Creates the exception with the default message.</summary>
    public OllamaMissingContextWindowException()
    {
    }

    /// <summary>Creates the exception with the given message.</summary>
    /// <param name="message">Text describing the failure.</param>
    public OllamaMissingContextWindowException(string message)
        : base(message)
    {
    }

    /// <summary>Creates the exception with the given message and the exception that caused it.</summary>
    /// <param name="message">Text describing the failure.</param>
    /// <param name="innerException">The underlying cause, or null if there is none.</param>
    public OllamaMissingContextWindowException(string message, Exception innerException)
        : base(message, innerException)
    {
    }
}
}