diff --git a/lib/ruby_llm.rb b/lib/ruby_llm.rb
index 7bb5f2808..3f9bfc113 100644
--- a/lib/ruby_llm.rb
+++ b/lib/ruby_llm.rb
@@ -14,6 +14,7 @@
   'ruby_llm' => 'RubyLLM',
   'llm' => 'LLM',
   'openai' => 'OpenAI',
+  'azure_openai' => 'AzureOpenAI',
   'api' => 'API',
   'deepseek' => 'DeepSeek',
   'perplexity' => 'Perplexity',
@@ -93,6 +94,7 @@ def logger
 RubyLLM::Provider.register :openrouter, RubyLLM::Providers::OpenRouter
 RubyLLM::Provider.register :perplexity, RubyLLM::Providers::Perplexity
 RubyLLM::Provider.register :vertexai, RubyLLM::Providers::VertexAI
+RubyLLM::Provider.register :azure_openai, RubyLLM::Providers::AzureOpenAI
 
 if defined?(Rails::Railtie)
   require 'ruby_llm/railtie'
diff --git a/lib/ruby_llm/configuration.rb b/lib/ruby_llm/configuration.rb
index eda2c3354..5de8e0b02 100644
--- a/lib/ruby_llm/configuration.rb
+++ b/lib/ruby_llm/configuration.rb
@@ -23,6 +23,10 @@ class Configuration
       :gpustack_api_base,
       :gpustack_api_key,
       :mistral_api_key,
+      # Azure OpenAI provider configuration
+      :azure_openai_api_base,
+      :azure_openai_api_version,
+      :azure_openai_api_key,
       # Default models
       :default_model,
       :default_embedding_model,
diff --git a/lib/ruby_llm/providers/azure_openai.rb b/lib/ruby_llm/providers/azure_openai.rb
new file mode 100644
index 000000000..9d680e41a
--- /dev/null
+++ b/lib/ruby_llm/providers/azure_openai.rb
@@ -0,0 +1,44 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Providers
+    # Azure OpenAI API integration. Derived from the OpenAI integration to
+    # support OpenAI capabilities via Microsoft Azure endpoints.
+    module AzureOpenAI
+      extend OpenAI
+      extend AzureOpenAI::Chat
+      extend AzureOpenAI::Streaming
+      extend AzureOpenAI::Models
+
+      module_function
+
+      def api_base(config)
+        # https://<endpoint>/openai/deployments/<deployment>/chat/completions?api-version=<version>
+        "#{config.azure_openai_api_base}/openai"
+      end
+
+      def headers(config)
+        {
+          'Authorization' => "Bearer #{config.azure_openai_api_key}"
+        }.compact
+      end
+
+      def capabilities
+        OpenAI::Capabilities
+      end
+
+      def slug
+        'azure_openai'
+      end
+
+      def configuration_requirements
+        %i[azure_openai_api_key azure_openai_api_base azure_openai_api_version]
+      end
+
+      def local?
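+        # Azure OpenAI endpoints are Microsoft-hosted, so this provider is never local.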
+        false
+      end
+    end
+  end
+end
diff --git a/lib/ruby_llm/providers/azure_openai/chat.rb b/lib/ruby_llm/providers/azure_openai/chat.rb
new file mode 100644
index 000000000..b1e7f515b
--- /dev/null
+++ b/lib/ruby_llm/providers/azure_openai/chat.rb
@@ -0,0 +1,31 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Providers
+    module AzureOpenAI
+      # Chat methods of the Azure OpenAI API integration
+      module Chat
+        extend OpenAI::Chat
+
+        module_function
+
+        def sync_response(connection, payload)
+          # Hold the config in an instance variable for use in completion_url and stream_url
+          @config = connection.config
+          super
+        end
+
+        def completion_url
+          # https://<endpoint>/openai/deployments/<deployment>/chat/completions?api-version=<version>
+          "deployments/#{@model_id}/chat/completions?api-version=#{@config.azure_openai_api_version}"
+        end
+
+        def render_payload(messages, tools:, temperature:, model:, stream: false)
+          # Hold the model id in an instance variable for use in completion_url and stream_url
+          @model_id = model
+          super
+        end
+      end
+    end
+  end
+end
diff --git a/lib/ruby_llm/providers/azure_openai/models.rb b/lib/ruby_llm/providers/azure_openai/models.rb
new file mode 100644
index 000000000..18573c170
--- /dev/null
+++ b/lib/ruby_llm/providers/azure_openai/models.rb
@@ -0,0 +1,33 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Providers
+    module AzureOpenAI
+      # Models methods of the Azure OpenAI API integration
+      module Models
+        extend OpenAI::Models
+
+        KNOWN_MODELS = [
+          'gpt-4o'
+        ].freeze
+
+        module_function
+
+        def models_url
+          'models?api-version=2024-10-21'
+        end
+
+        def parse_list_models_response(response, slug, capabilities)
+          # Keep only the known models, since the full list returned by
+          # Azure OpenAI is very long
+          response.body['data'].select! do |m|
+            KNOWN_MODELS.include?(m['id'])
+          end
+          # Use the OpenAI processor for the list, keeping in mind that
+          # pricing etc. won't be correct
+          super
+        end
+      end
+    end
+  end
+end
diff --git a/lib/ruby_llm/providers/azure_openai/streaming.rb b/lib/ruby_llm/providers/azure_openai/streaming.rb
new file mode 100644
index 000000000..139ee2578
--- /dev/null
+++ b/lib/ruby_llm/providers/azure_openai/streaming.rb
@@ -0,0 +1,20 @@
+# frozen_string_literal: true
+
+module RubyLLM
+  module Providers
+    module AzureOpenAI
+      # Streaming methods of the Azure OpenAI API integration
+      module Streaming
+        extend OpenAI::Streaming
+
+        module_function
+
+        def stream_response(connection, payload, &)
+          # Hold the config in an instance variable for use in completion_url and stream_url
+          @config = connection.config
+          super
+        end
+      end
+    end
+  end
+end
diff --git a/lib/tasks/models_update.rake b/lib/tasks/models_update.rake
new file mode 100644
index 000000000..5a8797704
--- /dev/null
+++ b/lib/tasks/models_update.rake
@@ -0,0 +1,87 @@
+# frozen_string_literal: true
+
+require 'dotenv/load'
+require 'ruby_llm'
+
+task default: ['models:update']
+
+namespace :models do
+  desc 'Update available models from providers (API keys needed)'
+  task :update do
+    puts 'Configuring RubyLLM...'
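+    # Keys are read from ENV; any provider left unset is reported as NOT CONFIGURED below.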
+    configure_from_env
+
+    refresh_models
+    display_model_stats
+  end
+end
+
+def configure_from_env
+  RubyLLM.configure do |config|
+    config.openai_api_key = ENV.fetch('OPENAI_API_KEY', nil)
+    config.anthropic_api_key = ENV.fetch('ANTHROPIC_API_KEY', nil)
+    config.gemini_api_key = ENV.fetch('GEMINI_API_KEY', nil)
+    config.deepseek_api_key = ENV.fetch('DEEPSEEK_API_KEY', nil)
+    config.openrouter_api_key = ENV.fetch('OPENROUTER_API_KEY', nil)
+    configure_bedrock(config)
+    configure_azure_openai(config)
+    config.request_timeout = 30
+  end
+end
+
+def configure_azure_openai(config)
+  config.azure_openai_api_base = ENV.fetch('AZURE_OPENAI_ENDPOINT', nil)
+  config.azure_openai_api_key = ENV.fetch('AZURE_OPENAI_API_KEY', nil)
+  config.azure_openai_api_version = ENV.fetch('AZURE_OPENAI_API_VER', nil)
+end
+
+def configure_bedrock(config)
+  config.bedrock_api_key = ENV.fetch('AWS_ACCESS_KEY_ID', nil)
+  config.bedrock_secret_key = ENV.fetch('AWS_SECRET_ACCESS_KEY', nil)
+  config.bedrock_region = ENV.fetch('AWS_REGION', nil)
+  config.bedrock_session_token = ENV.fetch('AWS_SESSION_TOKEN', nil)
+end
+
+def refresh_models
+  initial_count = RubyLLM.models.all.size
+  puts "Refreshing models (#{initial_count} cached)..."
+
+  models = RubyLLM.models.refresh!
+
+  if models.all.empty? && initial_count.zero?
+    puts 'Error: Failed to fetch models.'
+    exit(1)
+  elsif models.all.size == initial_count && initial_count.positive?
+    puts 'Warning: Model list unchanged.'
+  else
+    puts "Saving models.json (#{models.all.size} models)"
+    models.save_models
+  end
+
+  @models = models
+end
+
+def display_model_stats
+  puts "\nModel count:"
+  provider_counts = @models.all.group_by(&:provider).transform_values(&:count)
+
+  RubyLLM::Provider.providers.each_key do |sym|
+    name = sym.to_s.capitalize
+    count = provider_counts[sym.to_s] || 0
+    status = status(sym)
+    puts "  #{name}: #{count} models#{status}"
+  end
+
+  puts 'Refresh complete.'
+end
+
+def status(provider_sym)
+  if RubyLLM::Provider.providers[provider_sym].local?
+    ' (LOCAL - SKIP)'
+  elsif RubyLLM::Provider.providers[provider_sym].configured?
+    ' (OK)'
+  else
+    ' (NOT CONFIGURED)'
+  end
+end
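
Usage sketch for the new provider. The endpoint and version values below are
illustrative, and the model id must match an Azure deployment name, since
completion_url interpolates it into the deployments path. This assumes the
provider: override of RubyLLM.chat applies here as it does for the other
registered providers:

    RubyLLM.configure do |config|
      config.azure_openai_api_base    = 'https://my-resource.openai.azure.com' # hypothetical resource
      config.azure_openai_api_key     = ENV['AZURE_OPENAI_API_KEY']
      config.azure_openai_api_version = '2024-10-21'
    end

    chat = RubyLLM.chat(model: 'gpt-4o', provider: :azure_openai)
    chat.ask('Hello from Azure!')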