From 9d1cbb6f2e9c24f14902e4db2e75d7468f5b8ee3 Mon Sep 17 00:00:00 2001 From: gelluisaac Date: Thu, 9 Oct 2025 15:58:11 +0100 Subject: [PATCH] Implement a base class for AI model integrations --- packages/core/src/ai/README.md | 332 ++++++++++++++ .../core/src/ai/__tests__/base-model.test.ts | 270 +++++++++++ .../core/src/ai/__tests__/providers.test.ts | 283 ++++++++++++ packages/core/src/ai/base-model.ts | 427 ++++++++++++++++++ packages/core/src/ai/examples/model-usage.ts | 287 ++++++++++++ packages/core/src/ai/index.ts | 19 +- .../core/src/ai/providers/custom-model.ts | 113 +++++ .../src/ai/providers/huggingface-model.ts | 119 +++++ packages/core/src/ai/providers/index.ts | 3 + .../core/src/ai/providers/openai-model.ts | 125 +++++ packages/core/src/ai/types.ts | 175 +++++++ packages/core/src/index.ts | 6 + 12 files changed, 2158 insertions(+), 1 deletion(-) create mode 100644 packages/core/src/ai/README.md create mode 100644 packages/core/src/ai/__tests__/base-model.test.ts create mode 100644 packages/core/src/ai/__tests__/providers.test.ts create mode 100644 packages/core/src/ai/base-model.ts create mode 100644 packages/core/src/ai/examples/model-usage.ts create mode 100644 packages/core/src/ai/providers/custom-model.ts create mode 100644 packages/core/src/ai/providers/huggingface-model.ts create mode 100644 packages/core/src/ai/providers/index.ts create mode 100644 packages/core/src/ai/providers/openai-model.ts diff --git a/packages/core/src/ai/README.md b/packages/core/src/ai/README.md new file mode 100644 index 0000000..71b044a --- /dev/null +++ b/packages/core/src/ai/README.md @@ -0,0 +1,332 @@ +# AI Model Base Class + +This module provides a comprehensive base class for AI model integrations that can be extended for different ML services (OpenAI, Hugging Face, custom models). 
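+
+Because every provider extends the same `AIModel` contract, code that consumes a model can be written once against the base class and reused with any provider. A minimal sketch (the `summarize` helper below is illustrative only, not part of this module):
+
+```typescript
+import { AIModel, ModelInput } from './base-model';
+
+// Works with OpenAIModel, HuggingFaceModel, CustomModel, or any other AIModel subclass.
+async function summarize(model: AIModel, text: string): Promise<string> {
+  const input: ModelInput = {
+    prompt: `Summarize in one sentence: ${text}`,
+    maxTokens: 60,
+    temperature: 0.3,
+  };
+  const output = await model.generate(input);
+  return output.text;
+}
+```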
+ +## Features + +- **Abstract Base Class**: `AIModel` provides common functionality for all AI providers +- **Provider Implementations**: Ready-to-use implementations for OpenAI, Hugging Face, and custom models +- **Configuration Management**: Flexible configuration system supporting multiple API keys and model types +- **Error Handling**: Comprehensive error handling for network and API errors +- **Rate Limiting**: Built-in rate limiting with configurable limits +- **Batch Processing**: Support for processing multiple requests +- **TypeScript Support**: Full TypeScript interfaces and type safety +- **Testing**: Comprehensive unit tests for all functionality + +## Quick Start + +### Basic Usage + +```typescript +import { OpenAIModel } from './providers/openai-model'; +import { ModelInput } from './types'; + +const model = new OpenAIModel({ + apiKey: 'your-openai-key', + provider: 'openai', + modelVersion: 'gpt-3.5-turbo', +}); + +const input: ModelInput = { + prompt: 'Hello, world!', + maxTokens: 100, + temperature: 0.7, +}; + +const result = await model.generate(input); +console.log(result.text); +``` + +### Using Different Providers + +#### OpenAI +```typescript +import { OpenAIModel } from './providers/openai-model'; + +const model = new OpenAIModel({ + apiKey: 'your-openai-key', + provider: 'openai', + modelVersion: 'gpt-4', + organization: 'your-org-id', +}); +``` + +#### Hugging Face +```typescript +import { HuggingFaceModel } from './providers/huggingface-model'; + +const model = new HuggingFaceModel({ + apiKey: 'your-hf-key', + provider: 'huggingface', + modelVersion: 'microsoft/DialoGPT-medium', + useAuth: true, +}); +``` + +#### Custom Model +```typescript +import { CustomModel } from './providers/custom-model'; + +const model = new CustomModel({ + apiKey: 'your-custom-key', + provider: 'custom', + modelVersion: 'my-model-v1', + customEndpoint: 'https://api.mycompany.com/v1', +}); +``` + +## Architecture + +### Base Class (`AIModel`) + +The `AIModel` abstract class provides: + +- **Configuration Management**: API keys, timeouts, rate limits +- **Request/Response Formatting**: Standardized input/output interfaces +- **Error Handling**: Network errors, API errors, validation errors +- **Rate Limiting**: Configurable request throttling +- **Batch Processing**: Multiple request handling +- **Connection Testing**: Health checks for API endpoints + +### Provider Implementations + +#### OpenAI Model (`OpenAIModel`) +- Supports chat completions API +- Handles system messages and conversation history +- Configurable model versions (GPT-3.5, GPT-4, etc.) 
+- Organization support
+
+#### Hugging Face Model (`HuggingFaceModel`)
+- Inference API integration
+- Model status checking
+- Support for various model types
+- Authentication handling
+
+#### Custom Model (`CustomModel`)
+- Generic implementation for custom APIs
+- Flexible request/response parsing
+- Configurable endpoints
+- Extensible for specific needs
+
+## Configuration
+
+### Model Configuration
+
+```typescript
+interface ModelConfig {
+  apiKey: string; // Required API key
+  baseUrl?: string; // API base URL
+  timeout?: number; // Request timeout (ms)
+  rateLimit?: number; // Requests per minute
+  headers?: Record<string, string>; // Additional headers
+  modelOptions?: Record<string, any>; // Model-specific options
+}
+```
+
+### Provider-Specific Configuration
+
+#### OpenAI
+```typescript
+interface OpenAIConfig extends ModelConfig {
+  provider: 'openai';
+  modelVersion: string;
+  organization?: string;
+}
+```
+
+#### Hugging Face
+```typescript
+interface HuggingFaceConfig extends ModelConfig {
+  provider: 'huggingface';
+  modelVersion: string;
+  useAuth?: boolean;
+}
+```
+
+#### Custom Model
+```typescript
+interface CustomModelConfig extends ModelConfig {
+  provider: 'custom';
+  modelVersion: string;
+  customEndpoint: string;
+}
+```
+
+## Input/Output Interfaces
+
+### Model Input
+```typescript
+interface ModelInput {
+  prompt: string; // Required input text
+  maxTokens?: number; // Max tokens to generate
+  temperature?: number; // Response randomness (0-2)
+  topP?: number; // Top-p sampling (0-1)
+  stopSequences?: string[]; // Stop generation sequences
+  parameters?: Record<string, any>; // Additional parameters
+  systemMessage?: string; // System message for chat models
+  messages?: Array<{ // Conversation history
+    role: 'system' | 'user' | 'assistant';
+    content: string;
+  }>;
+}
+```
+
+### Model Output
+```typescript
+interface ModelOutput {
+  text: string; // Generated text
+  tokensUsed?: number; // Tokens consumed
+  metadata?: { // Response metadata
+    model: string;
+    finishReason?: string;
+    responseTime?: number;
+    provider?: string;
+  };
+  rawResponse?: any; // Raw API response
+  chunks?: string[]; // Streaming chunks
+}
+```
+
+## Error Handling
+
+The base class provides comprehensive error handling:
+
+```typescript
+class AIModelError extends Error {
+  constructor(
+    message: string,
+    public code: string,
+    public statusCode?: number,
+    public retryable: boolean = false
+  );
+}
+```
+
+### Error Types
+- `INVALID_INPUT`: Input validation errors
+- `INVALID_OUTPUT`: Output validation errors
+- `RATE_LIMIT_EXCEEDED`: Rate limiting errors
+- `NETWORK_ERROR`: Network connection errors
+- `TIMEOUT`: Request timeout errors
+- `API_ERROR`: API-specific errors
+- `SERVER_ERROR`: Server-side errors
+
+## Rate Limiting
+
+Built-in rate limiting with configurable limits:
+
+```typescript
+// Configure rate limiting
+const model = new OpenAIModel({
+  apiKey: 'your-key',
+  rateLimit: 60, // 60 requests per minute
+});
+
+// Check rate limit status
+const status = model.getRateLimitStatus();
+console.log(`Current: ${status.current}/${status.max}`);
+console.log(`Reset in: ${status.resetTime}ms`);
+```
+
+## Batch Processing
+
+Process multiple requests efficiently:
+
+```typescript
+const inputs: ModelInput[] = [
+  { prompt: 'First question' },
+  { prompt: 'Second question' },
+  { prompt: 'Third question' },
+];
+
+const results = await model.generateBatch(inputs);
+results.forEach((result, index) => {
+  console.log(`Response ${index + 1}:`, result.text);
+});
+```
+
+## Creating Custom Models
+
+Extend the `AIModel` base class for custom implementations:
+
+```typescript
+class MyCustomModel extends AIModel {
+  protected async makeRequest(input: ModelInput): Promise<ModelOutput> {
+    // Implement your custom API logic
+    const response = await this.httpClient.post('/your-endpoint', {
+      text: input.prompt,
+      max_length: input.maxTokens,
+    });
+
+    return {
+      text: response.data.generated_text,
+      tokensUsed: response.data.tokens_used,
+      metadata: {
+        model: 'my-custom-model',
+        finishReason: 'stop',
+        provider: 'custom',
+      },
+    };
+  }
+
+  protected getModelName(): string {
+    return 'my-custom-model';
+  }
+}
+```
+
+## Testing
+
+Comprehensive unit tests are provided:
+
+```bash
+# Run tests
+npm test
+
+# Run specific test files
+npm test base-model.test.ts
+npm test providers.test.ts
+```
+
+## Examples
+
+See `examples/model-usage.ts` for comprehensive usage examples including:
+
+- Basic model usage
+- Provider-specific configurations
+- Batch processing
+- Error handling
+- Rate limiting
+- Custom implementations
+- Configuration management
+
+## API Reference
+
+### AIModel Methods
+
+- `generate(input: ModelInput): Promise<ModelOutput>` - Generate text
+- `generateBatch(inputs: ModelInput[]): Promise<ModelOutput[]>` - Batch processing
+- `getConfig(): ModelConfig` - Get current configuration
+- `updateConfig(config: Partial<ModelConfig>): void` - Update configuration
+- `testConnection(): Promise<boolean>` - Test API connection
+- `getRateLimitStatus(): { current: number; max: number; resetTime: number } | null` - Get rate limit status
+
+### Abstract Methods (to implement)
+
+- `makeRequest(input: ModelInput): Promise<ModelOutput>` - Make API request
+- `getModelName(): string` - Get model name
+
+## Contributing
+
+When adding new providers or extending functionality:
+
+1. Extend the `AIModel` base class
+2. Implement required abstract methods
+3. Add provider-specific configuration interfaces
+4. Include comprehensive tests
+5. Update documentation
+6. Add usage examples
+
+## License
+
+MIT License - see LICENSE file for details.
diff --git a/packages/core/src/ai/__tests__/base-model.test.ts b/packages/core/src/ai/__tests__/base-model.test.ts new file mode 100644 index 0000000..0e3b8dc --- /dev/null +++ b/packages/core/src/ai/__tests__/base-model.test.ts @@ -0,0 +1,270 @@ +import { AIModel, ModelConfig, ModelInput, ModelOutput, AIModelError } from '../base-model'; +import axios from 'axios'; + +// Mock axios +jest.mock('axios'); +const mockedAxios = axios as jest.Mocked; + +// Mock axios.create to return a mock instance +mockedAxios.create.mockReturnValue({ + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, +} as any); + +// Test implementation of AIModel +class TestAIModel extends AIModel { + protected async makeRequest(input: ModelInput): Promise { + return { + text: 'Test response', + tokensUsed: 10, + metadata: { + model: 'test-model', + finishReason: 'stop', + }, + }; + } + + protected getModelName(): string { + return 'test-model'; + } +} + +describe('AIModel Base Class', () => { + let model: TestAIModel; + let config: ModelConfig; + + beforeEach(() => { + config = { + apiKey: 'test-api-key', + baseUrl: 'https://api.test.com', + timeout: 30000, + rateLimit: 60, + }; + model = new TestAIModel(config); + jest.clearAllMocks(); + }); + + describe('Constructor and Configuration', () => { + it('should initialize with default configuration', () => { + const defaultConfig: ModelConfig = { + apiKey: 'test-key', + }; + const testModel = new TestAIModel(defaultConfig); + + expect(testModel.getConfig()).toEqual({ + apiKey: 'test-key', + timeout: 30000, + rateLimit: 60, + }); + }); + + it('should merge provided configuration with defaults', () => { + const customConfig: ModelConfig = { + apiKey: 'custom-key', + timeout: 60000, + rateLimit: 120, + headers: { 'Custom-Header': 'value' }, + }; + const testModel = new TestAIModel(customConfig); + + expect(testModel.getConfig()).toEqual({ + apiKey: 'custom-key', + timeout: 60000, + rateLimit: 120, + headers: { 'Custom-Header': 'value' }, + }); + }); + }); + + describe('Input Validation', () => { + it('should validate required prompt', async () => { + const invalidInput = { prompt: '' } as ModelInput; + + await expect(model.generate(invalidInput)).rejects.toThrow(AIModelError); + await expect(model.generate(invalidInput)).rejects.toThrow('Prompt is required and must be a string'); + }); + + it('should validate maxTokens range', async () => { + const invalidInput: ModelInput = { + prompt: 'test', + maxTokens: 5000, + }; + + await expect(model.generate(invalidInput)).rejects.toThrow(AIModelError); + await expect(model.generate(invalidInput)).rejects.toThrow('maxTokens must be between 1 and 4000'); + }); + + it('should validate temperature range', async () => { + const invalidInput: ModelInput = { + prompt: 'test', + temperature: 3, + }; + + await expect(model.generate(invalidInput)).rejects.toThrow(AIModelError); + await expect(model.generate(invalidInput)).rejects.toThrow('temperature must be between 0 and 2'); + }); + + it('should validate topP range', async () => { + const invalidInput: ModelInput = { + prompt: 'test', + topP: 2, + }; + + await expect(model.generate(invalidInput)).rejects.toThrow(AIModelError); + await expect(model.generate(invalidInput)).rejects.toThrow('topP must be between 0 and 1'); + }); + + it('should accept valid input', async () => { + const validInput: ModelInput = { + prompt: 'test prompt', + maxTokens: 100, + temperature: 0.7, + topP: 0.9, + }; + + const result = await model.generate(validInput); + 
expect(result.text).toBe('Test response'); + }); + }); + + describe('Rate Limiting', () => { + it('should track rate limit status', () => { + const status = model.getRateLimitStatus(); + expect(status).toEqual({ + current: 0, + max: 60, + resetTime: 60000, + }); + }); + + it('should return null when rate limiting is disabled', () => { + const noLimitConfig: ModelConfig = { + apiKey: 'test-key', + rateLimit: undefined, + }; + const testModel = new TestAIModel(noLimitConfig); + + expect(testModel.getRateLimitStatus()).toBeNull(); + }); + }); + + describe('Batch Processing', () => { + it('should process multiple inputs', async () => { + const inputs: ModelInput[] = [ + { prompt: 'first prompt' }, + { prompt: 'second prompt' }, + { prompt: 'third prompt' }, + ]; + + const results = await model.generateBatch(inputs); + + expect(results).toHaveLength(3); + expect(results[0].text).toBe('Test response'); + expect(results[1].text).toBe('Test response'); + expect(results[2].text).toBe('Test response'); + }); + + it('should handle errors in batch processing', async () => { + // Mock the makeRequest method to throw an error for the second input + const originalMakeRequest = model['makeRequest']; + let callCount = 0; + model['makeRequest'] = jest.fn().mockImplementation(async (input: ModelInput) => { + callCount++; + if (callCount === 2) { + throw new Error('Test error'); + } + return originalMakeRequest.call(model, input); + }); + + const inputs: ModelInput[] = [ + { prompt: 'first prompt' }, + { prompt: 'second prompt' }, + { prompt: 'third prompt' }, + ]; + + const results = await model.generateBatch(inputs); + + expect(results).toHaveLength(3); + expect(results[0].text).toBe('Test response'); + expect(results[1].text).toBe(''); + expect(results[1].metadata?.finishReason).toBe('error'); + expect(results[2].text).toBe('Test response'); + }); + }); + + describe('Configuration Updates', () => { + it('should update configuration', () => { + const newConfig = { + timeout: 60000, + rateLimit: 120, + }; + + model.updateConfig(newConfig); + + const updatedConfig = model.getConfig(); + expect(updatedConfig.timeout).toBe(60000); + expect(updatedConfig.rateLimit).toBe(120); + expect(updatedConfig.apiKey).toBe('test-api-key'); // Should preserve existing values + }); + }); + + describe('Connection Testing', () => { + it('should test connection successfully', async () => { + const isConnected = await model.testConnection(); + expect(isConnected).toBe(true); + }); + + it('should handle connection test failures', async () => { + // Mock makeRequest to throw an error + model['makeRequest'] = jest.fn().mockRejectedValue(new Error('Connection failed')); + + const isConnected = await model.testConnection(); + expect(isConnected).toBe(false); + }); + }); + + describe('Error Handling', () => { + it('should handle network errors', async () => { + const networkError = new Error('Network error'); + (networkError as any).code = 'ENOTFOUND'; + + model['makeRequest'] = jest.fn().mockRejectedValue(networkError); + + await expect(model.generate({ prompt: 'test' })).rejects.toThrow(AIModelError); + await expect(model.generate({ prompt: 'test' })).rejects.toThrow('Network connection failed'); + }); + + it('should handle timeout errors', async () => { + const timeoutError = new Error('Timeout'); + (timeoutError as any).code = 'ECONNABORTED'; + + model['makeRequest'] = jest.fn().mockRejectedValue(timeoutError); + + await expect(model.generate({ prompt: 'test' })).rejects.toThrow(AIModelError); + await expect(model.generate({ 
prompt: 'test' })).rejects.toThrow('Request timeout'); + }); + + it('should handle unknown errors', async () => { + const unknownError = new Error('Unknown error'); + + model['makeRequest'] = jest.fn().mockRejectedValue(unknownError); + + await expect(model.generate({ prompt: 'test' })).rejects.toThrow(AIModelError); + await expect(model.generate({ prompt: 'test' })).rejects.toThrow('Unknown error occurred'); + }); + }); + + describe('Output Validation', () => { + it('should validate output format', async () => { + model['makeRequest'] = jest.fn().mockResolvedValue({ + text: '', // Invalid empty text + }); + + await expect(model.generate({ prompt: 'test' })).rejects.toThrow(AIModelError); + await expect(model.generate({ prompt: 'test' })).rejects.toThrow('Invalid output format: text is required'); + }); + }); +}); diff --git a/packages/core/src/ai/__tests__/providers.test.ts b/packages/core/src/ai/__tests__/providers.test.ts new file mode 100644 index 0000000..428726a --- /dev/null +++ b/packages/core/src/ai/__tests__/providers.test.ts @@ -0,0 +1,283 @@ +import { OpenAIModel } from '../providers/openai-model'; +import { HuggingFaceModel } from '../providers/huggingface-model'; +import { CustomModel } from '../providers/custom-model'; +import { OpenAIConfig, HuggingFaceConfig, CustomModelConfig } from '../types'; +import axios from 'axios'; + +// Mock axios +jest.mock('axios'); +const mockedAxios = axios as jest.Mocked; + +// Mock axios.create to return a mock instance +mockedAxios.create.mockReturnValue({ + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, + post: jest.fn(), + get: jest.fn(), +} as any); + +describe('AI Model Providers', () => { + describe('OpenAIModel', () => { + let model: OpenAIModel; + let config: OpenAIConfig; + + beforeEach(() => { + config = { + apiKey: 'test-openai-key', + provider: 'openai', + modelVersion: 'gpt-3.5-turbo', + organization: 'test-org', + }; + model = new OpenAIModel(config); + jest.clearAllMocks(); + }); + + it('should initialize with correct configuration', () => { + expect(model.getConfig().baseUrl).toBe('https://api.openai.com/v1'); + expect(model.getConfig().apiKey).toBe('test-openai-key'); + }); + + it('should build correct request body for chat completion', async () => { + const mockResponse = { + data: { + choices: [{ message: { content: 'Test response' } }], + usage: { total_tokens: 10 }, + model: 'gpt-3.5-turbo', + }, + }; + + mockedAxios.create.mockReturnValue({ + post: jest.fn().mockResolvedValue(mockResponse), + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, + } as any); + + const input = { + prompt: 'Hello, world!', + maxTokens: 100, + temperature: 0.7, + }; + + await model.generate(input); + + // Verify the request was made with correct parameters + expect(mockedAxios.create).toHaveBeenCalled(); + }); + + it('should handle system messages', async () => { + const mockResponse = { + data: { + choices: [{ message: { content: 'Test response' } }], + usage: { total_tokens: 10 }, + model: 'gpt-3.5-turbo', + }, + }; + + mockedAxios.create.mockReturnValue({ + post: jest.fn().mockResolvedValue(mockResponse), + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, + } as any); + + const input = { + prompt: 'Hello, world!', + systemMessage: 'You are a helpful assistant.', + }; + + await model.generate(input); + expect(mockedAxios.create).toHaveBeenCalled(); + }); + + it('should return correct 
capabilities', () => { + const capabilities = model.getCapabilities(); + + expect(capabilities.textGeneration).toBe(true); + expect(capabilities.chat).toBe(true); + expect(capabilities.streaming).toBe(true); + expect(capabilities.batchProcessing).toBe(true); + expect(capabilities.maxContextLength).toBe(4096); + expect(capabilities.languages).toContain('en'); + }); + + it('should get correct max context length for different models', () => { + const gpt4Config: OpenAIConfig = { + ...config, + modelVersion: 'gpt-4', + }; + const gpt4Model = new OpenAIModel(gpt4Config); + + const capabilities = gpt4Model.getCapabilities(); + expect(capabilities.maxContextLength).toBe(8192); + }); + }); + + describe('HuggingFaceModel', () => { + let model: HuggingFaceModel; + let config: HuggingFaceConfig; + + beforeEach(() => { + config = { + apiKey: 'test-hf-key', + provider: 'huggingface', + modelVersion: 'microsoft/DialoGPT-medium', + useAuth: true, + }; + model = new HuggingFaceModel(config); + jest.clearAllMocks(); + }); + + it('should initialize with correct configuration', () => { + expect(model.getConfig().baseUrl).toBe('https://api-inference.huggingface.co/models'); + expect(model.getConfig().apiKey).toBe('test-hf-key'); + }); + + it('should return correct capabilities', () => { + const capabilities = model.getCapabilities(); + + expect(capabilities.textGeneration).toBe(true); + expect(capabilities.chat).toBe(false); + expect(capabilities.streaming).toBe(false); + expect(capabilities.batchProcessing).toBe(true); + expect(capabilities.maxContextLength).toBe(1024); + }); + + it('should check model status', async () => { + const mockResponse = { data: { status: 'loaded' } }; + + mockedAxios.create.mockReturnValue({ + get: jest.fn().mockResolvedValue(mockResponse), + post: jest.fn(), + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, + } as any); + + const status = await model.checkModelStatus(); + expect(status.loaded).toBe(true); + expect(status.loading).toBe(false); + }); + + it('should handle model loading status', async () => { + const mockError = { + response: { status: 503 }, + }; + + mockedAxios.create.mockReturnValue({ + get: jest.fn().mockRejectedValue(mockError), + post: jest.fn(), + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, + } as any); + + const status = await model.checkModelStatus(); + expect(status.loaded).toBe(false); + expect(status.loading).toBe(true); + }); + }); + + describe('CustomModel', () => { + let model: CustomModel; + let config: CustomModelConfig; + + beforeEach(() => { + config = { + apiKey: 'test-custom-key', + provider: 'custom', + modelVersion: 'custom-model-v1', + customEndpoint: 'https://api.custom.com/v1', + }; + model = new CustomModel(config); + jest.clearAllMocks(); + }); + + it('should initialize with custom endpoint', () => { + expect(model.getConfig().baseUrl).toBe('https://api.custom.com/v1'); + expect(model.getConfig().apiKey).toBe('test-custom-key'); + }); + + it('should return default capabilities', () => { + const capabilities = model.getCapabilities(); + + expect(capabilities.textGeneration).toBe(true); + expect(capabilities.chat).toBe(false); + expect(capabilities.streaming).toBe(false); + expect(capabilities.batchProcessing).toBe(true); + expect(capabilities.maxContextLength).toBe(2048); + }); + + it('should handle various response formats', async () => { + const mockResponse = { + data: { text: 'Custom response' }, + }; + + 
mockedAxios.create.mockReturnValue({ + post: jest.fn().mockResolvedValue(mockResponse), + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, + } as any); + + const input = { prompt: 'test' }; + const result = await model.generate(input); + + expect(result.text).toBe('Custom response'); + expect(result.metadata?.provider).toBe('custom'); + }); + + it('should handle string responses', async () => { + const mockResponse = { + data: 'Simple string response', + }; + + mockedAxios.create.mockReturnValue({ + post: jest.fn().mockResolvedValue(mockResponse), + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, + } as any); + + const input = { prompt: 'test' }; + const result = await model.generate(input); + + expect(result.text).toBe('Simple string response'); + }); + + it('should handle array responses', async () => { + const mockResponse = { + data: [{ generated_text: 'Generated text' }], + }; + + mockedAxios.create.mockReturnValue({ + post: jest.fn().mockResolvedValue(mockResponse), + interceptors: { + request: { use: jest.fn() }, + response: { use: jest.fn() }, + }, + defaults: {}, + } as any); + + const input = { prompt: 'test' }; + const result = await model.generate(input); + + expect(result.text).toBe('Generated text'); + }); + }); +}); diff --git a/packages/core/src/ai/base-model.ts b/packages/core/src/ai/base-model.ts new file mode 100644 index 0000000..214f3ea --- /dev/null +++ b/packages/core/src/ai/base-model.ts @@ -0,0 +1,427 @@ +import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse } from 'axios'; + +/** + * Configuration options for AI model providers + */ +export interface ModelConfig { + /** API key for authentication */ + apiKey: string; + /** Base URL for the API endpoint */ + baseUrl?: string; + /** Request timeout in milliseconds */ + timeout?: number; + /** Maximum number of requests per minute */ + rateLimit?: number; + /** Additional headers to include in requests */ + headers?: Record; + /** Model-specific configuration */ + modelOptions?: Record; +} + +/** + * Standard input format for AI model requests + */ +export interface ModelInput { + /** The input prompt or text */ + prompt: string; + /** Maximum number of tokens to generate */ + maxTokens?: number; + /** Temperature for response randomness (0-1) */ + temperature?: number; + /** Top-p sampling parameter */ + topP?: number; + /** Stop sequences to end generation */ + stopSequences?: string[]; + /** Additional model-specific parameters */ + parameters?: Record; + /** System message for conversational models */ + systemMessage?: string; + /** Conversation history for chat models */ + messages?: Array<{ + role: 'system' | 'user' | 'assistant'; + content: string; + }>; +} + +/** + * Standard output format for AI model responses + */ +export interface ModelOutput { + /** The generated text response */ + text: string; + /** Number of tokens used in the request */ + tokensUsed?: number; + /** Model metadata */ + metadata?: { + model: string; + finishReason?: string; + responseTime?: number; + provider?: string; + }; + /** Raw response from the API */ + rawResponse?: any; +} + +/** + * Error types for AI model operations + */ +export class AIModelError extends Error { + constructor( + message: string, + public code: string, + public statusCode?: number, + public retryable: boolean = false + ) { + super(message); + this.name = 'AIModelError'; + } +} + +/** + * Rate limiting configuration + */ +interface 
RateLimitConfig { + maxRequests: number; + windowMs: number; + currentRequests: number; + windowStart: number; +} + +/** + * Abstract base class for AI model integrations + * Provides common functionality for different ML services + */ +export abstract class AIModel { + protected config: ModelConfig; + protected httpClient: AxiosInstance; + protected rateLimitConfig: RateLimitConfig | null = null; + + constructor(config: ModelConfig) { + this.config = { + timeout: 30000, + rateLimit: 60, + ...config, + }; + + // Initialize HTTP client + this.httpClient = axios.create({ + baseURL: this.config.baseUrl, + timeout: this.config.timeout, + headers: { + 'Content-Type': 'application/json', + ...this.config.headers, + }, + }); + + // Setup rate limiting if configured + if (this.config.rateLimit) { + this.rateLimitConfig = { + maxRequests: this.config.rateLimit, + windowMs: 60000, // 1 minute + currentRequests: 0, + windowStart: Date.now(), + }; + } + + // Setup request/response interceptors + this.setupInterceptors(); + } + + /** + * Abstract method to be implemented by concrete model classes + * Handles the actual API call to the specific model provider + */ + protected abstract makeRequest(input: ModelInput): Promise; + + /** + * Public method to generate text using the model + */ + async generate(input: ModelInput): Promise { + try { + // Validate input + this.validateInput(input); + + // Check rate limits + await this.checkRateLimit(); + + // Make the request + const result = await this.makeRequest(input); + + // Validate output + this.validateOutput(result); + + return result; + } catch (error) { + throw this.handleError(error); + } + } + + /** + * Batch processing for multiple inputs + */ + async generateBatch(inputs: ModelInput[]): Promise { + const results: ModelOutput[] = []; + + for (const input of inputs) { + try { + const result = await this.generate(input); + results.push(result); + } catch (error) { + // Add error result for failed requests + results.push({ + text: '', + metadata: { + model: this.getModelName(), + finishReason: 'error', + }, + rawResponse: { error: error instanceof Error ? 
error.message : 'Unknown error' }, + }); + } + } + + return results; + } + + /** + * Get the model name (to be implemented by concrete classes) + */ + protected abstract getModelName(): string; + + /** + * Setup HTTP interceptors for request/response handling + */ + private setupInterceptors(): void { + // Request interceptor for authentication + this.httpClient.interceptors.request.use( + (config) => { + // Add authentication header + if (config.headers) { + config.headers.Authorization = `Bearer ${this.config.apiKey}`; + } + return config; + }, + (error) => Promise.reject(error) + ); + + // Response interceptor for error handling + this.httpClient.interceptors.response.use( + (response) => response, + (error) => { + if (error.response) { + // API returned an error response + const statusCode = error.response.status; + const message = error.response.data?.message || error.message; + + if (statusCode === 429) { + throw new AIModelError( + 'Rate limit exceeded', + 'RATE_LIMIT_EXCEEDED', + statusCode, + true + ); + } else if (statusCode >= 500) { + throw new AIModelError( + 'Server error', + 'SERVER_ERROR', + statusCode, + true + ); + } else if (statusCode === 401) { + throw new AIModelError( + 'Invalid API key', + 'INVALID_API_KEY', + statusCode, + false + ); + } else { + throw new AIModelError( + message, + 'API_ERROR', + statusCode, + false + ); + } + } else if (error.request) { + // Network error + throw new AIModelError( + 'Network error - no response received', + 'NETWORK_ERROR', + undefined, + true + ); + } else { + // Other error + throw new AIModelError( + error.message, + 'UNKNOWN_ERROR', + undefined, + false + ); + } + } + ); + } + + /** + * Check and enforce rate limits + */ + private async checkRateLimit(): Promise { + if (!this.rateLimitConfig) return; + + const now = Date.now(); + const { maxRequests, windowMs, currentRequests, windowStart } = this.rateLimitConfig; + + // Reset window if needed + if (now - windowStart >= windowMs) { + this.rateLimitConfig.currentRequests = 0; + this.rateLimitConfig.windowStart = now; + } + + // Check if we've exceeded the rate limit + if (currentRequests >= maxRequests) { + const waitTime = windowMs - (now - windowStart); + throw new AIModelError( + `Rate limit exceeded. 
Try again in ${Math.ceil(waitTime / 1000)} seconds`, + 'RATE_LIMIT_EXCEEDED', + undefined, + true + ); + } + + // Increment request count + this.rateLimitConfig.currentRequests++; + } + + /** + * Validate input parameters + */ + private validateInput(input: ModelInput): void { + if (!input.prompt || typeof input.prompt !== 'string') { + throw new AIModelError('Prompt is required and must be a string', 'INVALID_INPUT'); + } + + if (input.maxTokens && (input.maxTokens < 1 || input.maxTokens > 4000)) { + throw new AIModelError('maxTokens must be between 1 and 4000', 'INVALID_INPUT'); + } + + if (input.temperature && (input.temperature < 0 || input.temperature > 2)) { + throw new AIModelError('temperature must be between 0 and 2', 'INVALID_INPUT'); + } + + if (input.topP && (input.topP < 0 || input.topP > 1)) { + throw new AIModelError('topP must be between 0 and 1', 'INVALID_INPUT'); + } + } + + /** + * Validate output format + */ + private validateOutput(output: ModelOutput): void { + if (!output.text || typeof output.text !== 'string') { + throw new AIModelError('Invalid output format: text is required', 'INVALID_OUTPUT'); + } + } + + /** + * Handle and transform errors + */ + private handleError(error: any): AIModelError { + if (error instanceof AIModelError) { + return error; + } + + if (error.code === 'ECONNABORTED') { + return new AIModelError( + 'Request timeout', + 'TIMEOUT', + undefined, + true + ); + } + + if (error.code === 'ENOTFOUND' || error.code === 'ECONNREFUSED') { + return new AIModelError( + 'Network connection failed', + 'NETWORK_ERROR', + undefined, + true + ); + } + + return new AIModelError( + error.message || 'Unknown error occurred', + 'UNKNOWN_ERROR', + undefined, + false + ); + } + + /** + * Get current configuration + */ + getConfig(): ModelConfig { + return { ...this.config }; + } + + /** + * Update configuration + */ + updateConfig(newConfig: Partial): void { + this.config = { ...this.config, ...newConfig }; + + // Update HTTP client configuration + if (newConfig.baseUrl) { + this.httpClient.defaults.baseURL = newConfig.baseUrl; + } + + if (newConfig.timeout) { + this.httpClient.defaults.timeout = newConfig.timeout; + } + + // Update rate limiting + if (newConfig.rateLimit) { + this.rateLimitConfig = { + maxRequests: newConfig.rateLimit, + windowMs: 60000, + currentRequests: 0, + windowStart: Date.now(), + }; + } + } + + /** + * Test the connection to the model API + */ + async testConnection(): Promise { + try { + const testInput: ModelInput = { + prompt: 'test', + maxTokens: 1, + }; + + await this.generate(testInput); + return true; + } catch (error) { + return false; + } + } + + /** + * Get rate limit status + */ + getRateLimitStatus(): { current: number; max: number; resetTime: number } | null { + if (!this.rateLimitConfig) return null; + + const now = Date.now(); + const { maxRequests, windowMs, currentRequests, windowStart } = this.rateLimitConfig; + + const resetTime = windowStart + windowMs; + const timeUntilReset = Math.max(0, resetTime - now); + + return { + current: currentRequests, + max: maxRequests, + resetTime: timeUntilReset, + }; + } +} diff --git a/packages/core/src/ai/examples/model-usage.ts b/packages/core/src/ai/examples/model-usage.ts new file mode 100644 index 0000000..285621e --- /dev/null +++ b/packages/core/src/ai/examples/model-usage.ts @@ -0,0 +1,287 @@ +/** + * Example usage of the AI Model base class and providers + * Demonstrates how to extend and use the base AIModel class + */ + +import { AIModel, ModelInput, ModelOutput, 
ModelConfig } from '../base-model'; +import { OpenAIModel } from '../providers/openai-model'; +import { HuggingFaceModel } from '../providers/huggingface-model'; +import { CustomModel } from '../providers/custom-model'; +import { OpenAIConfig, HuggingFaceConfig, CustomModelConfig } from '../types'; + +// Example 1: Using OpenAI Model +export async function openaiExample() { + const config: OpenAIConfig = { + apiKey: process.env.OPENAI_API_KEY || 'your-openai-key', + provider: 'openai', + modelVersion: 'gpt-3.5-turbo', + organization: 'your-org-id', + timeout: 30000, + rateLimit: 60, + }; + + const model = new OpenAIModel(config); + + const input: ModelInput = { + prompt: 'Explain quantum computing in simple terms', + maxTokens: 150, + temperature: 0.7, + systemMessage: 'You are a helpful science teacher.', + }; + + try { + const result = await model.generate(input); + console.log('OpenAI Response:', result.text); + console.log('Tokens used:', result.tokensUsed); + console.log('Model:', result.metadata?.model); + } catch (error) { + console.error('OpenAI Error:', error); + } +} + +// Example 2: Using Hugging Face Model +export async function huggingfaceExample() { + const config: HuggingFaceConfig = { + apiKey: process.env.HUGGINGFACE_API_KEY || 'your-hf-key', + provider: 'huggingface', + modelVersion: 'microsoft/DialoGPT-medium', + useAuth: true, + timeout: 30000, + }; + + const model = new HuggingFaceModel(config); + + const input: ModelInput = { + prompt: 'The future of artificial intelligence is', + maxTokens: 50, + temperature: 0.8, + }; + + try { + // Check if model is loaded + const status = await model.checkModelStatus(); + if (!status.loaded) { + console.log('Model is loading, please wait...'); + return; + } + + const result = await model.generate(input); + console.log('Hugging Face Response:', result.text); + } catch (error) { + console.error('Hugging Face Error:', error); + } +} + +// Example 3: Using Custom Model +export async function customModelExample() { + const config: CustomModelConfig = { + apiKey: process.env.CUSTOM_API_KEY || 'your-custom-key', + provider: 'custom', + modelVersion: 'my-custom-model-v1', + customEndpoint: 'https://api.mycompany.com/v1/chat', + timeout: 30000, + headers: { + 'X-Custom-Header': 'value', + }, + }; + + const model = new CustomModel(config); + + const input: ModelInput = { + prompt: 'Generate a creative story about a robot', + maxTokens: 200, + temperature: 0.9, + }; + + try { + const result = await model.generate(input); + console.log('Custom Model Response:', result.text); + } catch (error) { + console.error('Custom Model Error:', error); + } +} + +// Example 4: Batch Processing +export async function batchProcessingExample() { + const config: OpenAIConfig = { + apiKey: process.env.OPENAI_API_KEY || 'your-openai-key', + provider: 'openai', + modelVersion: 'gpt-3.5-turbo', + }; + + const model = new OpenAIModel(config); + + const inputs: ModelInput[] = [ + { prompt: 'What is machine learning?' }, + { prompt: 'Explain neural networks' }, + { prompt: 'What is deep learning?' 
}, + ]; + + try { + const results = await model.generateBatch(inputs); + + results.forEach((result, index) => { + console.log(`Question ${index + 1}:`, inputs[index].prompt); + console.log(`Answer:`, result.text); + console.log('---'); + }); + } catch (error) { + console.error('Batch Processing Error:', error); + } +} + +// Example 5: Custom Model Implementation +class MyCustomModel extends AIModel { + protected async makeRequest(input: ModelInput): Promise { + // Custom implementation for your specific API + const response = await this.httpClient.post('/generate', { + text: input.prompt, + max_length: input.maxTokens || 100, + temperature: input.temperature || 0.7, + }); + + return { + text: response.data.generated_text, + tokensUsed: response.data.tokens_used, + metadata: { + model: 'my-custom-model', + finishReason: 'stop', + provider: 'custom', + }, + rawResponse: response.data, + }; + } + + protected getModelName(): string { + return 'my-custom-model'; + } + + getCapabilities() { + return { + textGeneration: true, + chat: false, + streaming: false, + batchProcessing: true, + maxContextLength: 1024, + languages: ['en'], + }; + } +} + +export async function customImplementationExample() { + const config: ModelConfig = { + apiKey: 'your-api-key', + baseUrl: 'https://api.yourcompany.com', + timeout: 30000, + rateLimit: 100, + }; + + const model = new MyCustomModel(config); + + const input: ModelInput = { + prompt: 'Generate a product description for a smartwatch', + maxTokens: 100, + temperature: 0.8, + }; + + try { + const result = await model.generate(input); + console.log('Custom Implementation Response:', result.text); + } catch (error) { + console.error('Custom Implementation Error:', error); + } +} + +// Example 6: Error Handling and Retry Logic +export async function errorHandlingExample() { + const config: OpenAIConfig = { + apiKey: process.env.OPENAI_API_KEY || 'your-openai-key', + provider: 'openai', + modelVersion: 'gpt-3.5-turbo', + retryConfig: { + maxRetries: 3, + retryDelay: 1000, + backoffMultiplier: 2, + }, + }; + + const model = new OpenAIModel(config); + + const input: ModelInput = { + prompt: 'Test prompt', + maxTokens: 10, + }; + + try { + const result = await model.generate(input); + console.log('Success:', result.text); + } catch (error: any) { + if (error.code === 'RATE_LIMIT_EXCEEDED') { + console.log('Rate limit exceeded, please wait...'); + const status = model.getRateLimitStatus(); + if (status) { + console.log(`Try again in ${Math.ceil(status.resetTime / 1000)} seconds`); + } + } else if (error.retryable) { + console.log('Retryable error occurred:', error.message); + } else { + console.log('Non-retryable error:', error.message); + } + } +} + +// Example 7: Configuration Management +export async function configurationExample() { + const model = new OpenAIModel({ + apiKey: 'initial-key', + provider: 'openai', + modelVersion: 'gpt-3.5-turbo', + timeout: 30000, + }); + + console.log('Initial config:', model.getConfig()); + + // Update configuration + model.updateConfig({ + apiKey: 'new-key', + timeout: 60000, + rateLimit: 120, + }); + + console.log('Updated config:', model.getConfig()); + + // Test connection + const isConnected = await model.testConnection(); + console.log('Connection status:', isConnected); +} + +// Example 8: Rate Limiting +export async function rateLimitingExample() { + const model = new OpenAIModel({ + apiKey: 'your-key', + provider: 'openai', + modelVersion: 'gpt-3.5-turbo', + rateLimit: 2, // Very low limit for testing + }); + + const 
input: ModelInput = { + prompt: 'Test', + maxTokens: 5, + }; + + try { + // Make multiple requests quickly + const promises = Array(5).fill(null).map(() => model.generate(input)); + const results = await Promise.allSettled(promises); + + results.forEach((result, index) => { + if (result.status === 'fulfilled') { + console.log(`Request ${index + 1}: Success`); + } else { + console.log(`Request ${index + 1}: Failed - ${result.reason.message}`); + } + }); + } catch (error) { + console.error('Rate limiting error:', error); + } +} diff --git a/packages/core/src/ai/index.ts b/packages/core/src/ai/index.ts index 1d6056d..7e61534 100644 --- a/packages/core/src/ai/index.ts +++ b/packages/core/src/ai/index.ts @@ -1,5 +1,22 @@ export * from './service'; export * from './credit-scorer'; export * from './fraud-detector'; -export * from './types'; export * from './recommendations'; +export * from './base-model'; +export * from './providers'; + +// Re-export types with explicit names to avoid conflicts +export type { + ModelInput, + ModelOutput, + ModelConfig, + ExtendedModelConfig, + OpenAIConfig, + HuggingFaceConfig, + CustomModelConfig, + BatchConfig, + ModelCapabilities, + RateLimitConfig, +} from './types'; + +export { AIModelError } from './types'; diff --git a/packages/core/src/ai/providers/custom-model.ts b/packages/core/src/ai/providers/custom-model.ts new file mode 100644 index 0000000..5c4f9e1 --- /dev/null +++ b/packages/core/src/ai/providers/custom-model.ts @@ -0,0 +1,113 @@ +import { AIModel, ModelConfig, ModelInput, ModelOutput } from '../base-model'; +import { CustomModelConfig } from '../types'; + +/** + * Custom model implementation + * Extends the base AIModel class for custom API integration + */ +export class CustomModel extends AIModel { + private modelVersion: string; + private customEndpoint: string; + + constructor(config: CustomModelConfig) { + super({ + ...config, + baseUrl: config.baseUrl || config.customEndpoint, + }); + + this.modelVersion = config.modelVersion; + this.customEndpoint = config.customEndpoint; + } + + protected async makeRequest(input: ModelInput): Promise { + const requestBody = this.buildRequestBody(input); + + const response = await this.httpClient.post('/', requestBody); + + return this.parseResponse(response); + } + + protected getModelName(): string { + return this.modelVersion; + } + + /** + * Build custom model request body + * This is a generic implementation that can be overridden + */ + private buildRequestBody(input: ModelInput): any { + return { + prompt: input.prompt, + max_tokens: input.maxTokens, + temperature: input.temperature, + top_p: input.topP, + stop: input.stopSequences, + ...input.parameters, + }; + } + + /** + * Parse custom model response + * This is a generic implementation that can be overridden + */ + private parseResponse(response: any): ModelOutput { + const data = response.data; + + // Try to extract text from common response formats + let text: string; + + if (data.text) { + text = data.text; + } else if (data.response) { + text = data.response; + } else if (data.output) { + text = data.output; + } else if (data.content) { + text = data.content; + } else if (typeof data === 'string') { + text = data; + } else { + text = JSON.stringify(data); + } + + return { + text, + tokensUsed: data.tokens_used || data.tokensUsed, + metadata: { + model: this.modelVersion, + finishReason: data.finish_reason || 'stop', + provider: 'custom', + }, + rawResponse: data, + }; + } + + /** + * Get model capabilities + * Default implementation - 
should be overridden for specific models + */ + getCapabilities() { + return { + textGeneration: true, + chat: false, + streaming: false, + batchProcessing: true, + maxContextLength: 2048, + languages: ['en'], + }; + } + + /** + * Override this method to customize request body format + */ + protected buildCustomRequestBody(input: ModelInput): any { + return this.buildRequestBody(input); + } + + /** + * Override this method to customize response parsing + */ + protected parseCustomResponse(response: any): ModelOutput { + return this.parseResponse(response); + } +} diff --git a/packages/core/src/ai/providers/huggingface-model.ts b/packages/core/src/ai/providers/huggingface-model.ts new file mode 100644 index 0000000..f2ca12d --- /dev/null +++ b/packages/core/src/ai/providers/huggingface-model.ts @@ -0,0 +1,119 @@ +import { AIModel, ModelConfig, ModelInput, ModelOutput } from '../base-model'; +import { HuggingFaceConfig } from '../types'; + +/** + * Hugging Face model implementation + * Extends the base AIModel class for Hugging Face API integration + */ +export class HuggingFaceModel extends AIModel { + private modelVersion: string; + private useAuth: boolean; + + constructor(config: HuggingFaceConfig) { + super({ + ...config, + baseUrl: config.baseUrl || 'https://api-inference.huggingface.co/models', + }); + + this.modelVersion = config.modelVersion; + this.useAuth = config.useAuth || false; + } + + protected async makeRequest(input: ModelInput): Promise { + const requestBody = this.buildRequestBody(input); + + const response = await this.httpClient.post(`/${this.modelVersion}`, requestBody); + + return this.parseResponse(response); + } + + protected getModelName(): string { + return this.modelVersion; + } + + /** + * Build Hugging Face-specific request body + */ + private buildRequestBody(input: ModelInput): any { + return { + inputs: input.prompt, + parameters: { + max_new_tokens: input.maxTokens || 100, + temperature: input.temperature || 0.7, + top_p: input.topP || 0.9, + do_sample: true, + return_full_text: false, + ...input.parameters, + }, + }; + } + + /** + * Parse Hugging Face response format + */ + private parseResponse(response: any): ModelOutput { + const data = response.data; + + // Handle different response formats + let text: string; + let tokensUsed: number | undefined; + + if (Array.isArray(data) && data.length > 0) { + // Text generation response + text = data[0].generated_text || data[0].text || ''; + } else if (data.generated_text) { + // Single response + text = data.generated_text; + } else if (data.text) { + text = data.text; + } else { + text = JSON.stringify(data); + } + + return { + text, + tokensUsed, + metadata: { + model: this.modelVersion, + finishReason: 'stop', + provider: 'huggingface', + }, + rawResponse: data, + }; + } + + /** + * Get model capabilities + */ + getCapabilities() { + return { + textGeneration: true, + chat: false, // Most HF models are not conversational + streaming: false, + batchProcessing: true, + maxContextLength: 1024, // Default for most HF models + languages: ['en'], // Depends on the specific model + }; + } + + /** + * Check if model is loaded (Hugging Face specific) + */ + async checkModelStatus(): Promise<{ loaded: boolean; loading?: boolean }> { + try { + const response = await this.httpClient.get(`/${this.modelVersion}`); + return { + loaded: true, + loading: false, + }; + } catch (error: any) { + if (error.response?.status === 503) { + return { + loaded: false, + loading: true, + }; + } + throw error; + } + } +} diff --git 
a/packages/core/src/ai/providers/index.ts b/packages/core/src/ai/providers/index.ts new file mode 100644 index 0000000..d093e0d --- /dev/null +++ b/packages/core/src/ai/providers/index.ts @@ -0,0 +1,3 @@ +export { OpenAIModel } from './openai-model'; +export { HuggingFaceModel } from './huggingface-model'; +export { CustomModel } from './custom-model'; diff --git a/packages/core/src/ai/providers/openai-model.ts b/packages/core/src/ai/providers/openai-model.ts new file mode 100644 index 0000000..6edd05d --- /dev/null +++ b/packages/core/src/ai/providers/openai-model.ts @@ -0,0 +1,125 @@ +import { AIModel, ModelConfig, ModelInput, ModelOutput } from '../base-model'; +import { OpenAIConfig } from '../types'; + +/** + * OpenAI model implementation + * Extends the base AIModel class for OpenAI API integration + */ +export class OpenAIModel extends AIModel { + private modelVersion: string; + private organization?: string; + + constructor(config: OpenAIConfig) { + super({ + ...config, + baseUrl: config.baseUrl || 'https://api.openai.com/v1', + }); + + this.modelVersion = config.modelVersion; + this.organization = config.organization; + } + + protected async makeRequest(input: ModelInput): Promise { + const requestBody = this.buildRequestBody(input); + + const response = await this.httpClient.post('/chat/completions', requestBody); + + return this.parseResponse(response); + } + + protected getModelName(): string { + return this.modelVersion; + } + + /** + * Build OpenAI-specific request body + */ + private buildRequestBody(input: ModelInput): any { + const messages = this.buildMessages(input); + + return { + model: this.modelVersion, + messages, + max_tokens: input.maxTokens || 1000, + temperature: input.temperature || 0.7, + top_p: input.topP || 1, + stop: input.stopSequences, + ...input.parameters, + }; + } + + /** + * Build messages array for OpenAI chat format + */ + private buildMessages(input: ModelInput): Array<{ role: string; content: string }> { + const messages: Array<{ role: string; content: string }> = []; + + // Add system message if provided + if (input.systemMessage) { + messages.push({ + role: 'system', + content: input.systemMessage, + }); + } + + // Use provided messages or create from prompt + if (input.messages && input.messages.length > 0) { + messages.push(...input.messages); + } else { + messages.push({ + role: 'user', + content: input.prompt, + }); + } + + return messages; + } + + /** + * Parse OpenAI response format + */ + private parseResponse(response: any): ModelOutput { + const choice = response.data.choices[0]; + const usage = response.data.usage; + + return { + text: choice.message.content, + tokensUsed: usage?.total_tokens, + metadata: { + model: response.data.model, + finishReason: choice.finish_reason, + provider: 'openai', + }, + rawResponse: response.data, + }; + } + + /** + * Get model capabilities + */ + getCapabilities() { + return { + textGeneration: true, + chat: true, + streaming: true, + batchProcessing: true, + maxContextLength: this.getMaxContextLength(), + languages: ['en', 'es', 'fr', 'de', 'it', 'pt', 'ru', 'ja', 'ko', 'zh'], + }; + } + + /** + * Get maximum context length based on model version + */ + private getMaxContextLength(): number { + const contextLengths: Record = { + 'gpt-3.5-turbo': 4096, + 'gpt-3.5-turbo-16k': 16384, + 'gpt-4': 8192, + 'gpt-4-32k': 32768, + 'gpt-4-turbo': 128000, + }; + + return contextLengths[this.modelVersion] || 4096; + } +} diff --git a/packages/core/src/ai/types.ts b/packages/core/src/ai/types.ts index 
2daf35a..6da656a 100644 --- a/packages/core/src/ai/types.ts +++ b/packages/core/src/ai/types.ts @@ -4,6 +4,24 @@ export interface AIConfig { timeout?: number; } +/** + * Base configuration for AI model providers + */ +export interface ModelConfig { + /** API key for authentication */ + apiKey: string; + /** Base URL for the API endpoint */ + baseUrl?: string; + /** Request timeout in milliseconds */ + timeout?: number; + /** Maximum number of requests per minute */ + rateLimit?: number; + /** Additional headers to include in requests */ + headers?: Record; + /** Model-specific configuration */ + modelOptions?: Record; +} + export interface CreditScoreResult { score: number; factors: string[]; @@ -15,3 +33,160 @@ export interface FraudDetectionResult { riskScore: number; factors: string[]; } + +/** + * Extended configuration for AI model providers + * Extends the base ModelConfig from base-model.ts + */ +export interface ExtendedModelConfig { + /** API key for authentication */ + apiKey: string; + /** Base URL for the API endpoint */ + baseUrl?: string; + /** Request timeout in milliseconds */ + timeout?: number; + /** Maximum number of requests per minute */ + rateLimit?: number; + /** Additional headers to include in requests */ + headers?: Record; + /** Model-specific configuration */ + modelOptions?: Record; + /** Provider-specific settings */ + provider?: 'openai' | 'huggingface' | 'custom'; + /** Model version or identifier */ + modelVersion?: string; + /** Enable/disable streaming responses */ + streaming?: boolean; + /** Retry configuration */ + retryConfig?: { + maxRetries: number; + retryDelay: number; + backoffMultiplier: number; + }; +} + +/** + * Standard input format for AI model requests + * Provides a consistent interface across different providers + */ +export interface ModelInput { + /** The input prompt or text */ + prompt: string; + /** Maximum number of tokens to generate */ + maxTokens?: number; + /** Temperature for response randomness (0-1) */ + temperature?: number; + /** Top-p sampling parameter */ + topP?: number; + /** Stop sequences to end generation */ + stopSequences?: string[]; + /** Additional model-specific parameters */ + parameters?: Record; + /** System message for conversational models */ + systemMessage?: string; + /** Conversation history for chat models */ + messages?: Array<{ + role: 'system' | 'user' | 'assistant'; + content: string; + }>; +} + +/** + * Standard output format for AI model responses + * Ensures consistent response structure across providers + */ +export interface ModelOutput { + /** The generated text response */ + text: string; + /** Number of tokens used in the request */ + tokensUsed?: number; + /** Model metadata */ + metadata?: { + model: string; + finishReason?: string; + responseTime?: number; + provider?: string; + }; + /** Raw response from the API */ + rawResponse?: any; + /** Streaming response chunks (if applicable) */ + chunks?: string[]; +} + +/** + * Error types for AI model operations + * Provides structured error handling + */ +export class AIModelError extends Error { + constructor( + message: string, + public code: string, + public statusCode?: number, + public retryable: boolean = false + ) { + super(message); + this.name = 'AIModelError'; + } +} + +/** + * Rate limiting configuration + * Manages request throttling + */ +export interface RateLimitConfig { + maxRequests: number; + windowMs: number; + currentRequests: number; + windowStart: number; +} + +/** + * Provider-specific configuration interfaces + */ 
+export interface OpenAIConfig extends ExtendedModelConfig { + provider: 'openai'; + modelVersion: string; + organization?: string; +} + +export interface HuggingFaceConfig extends ExtendedModelConfig { + provider: 'huggingface'; + modelVersion: string; + useAuth?: boolean; +} + +export interface CustomModelConfig extends ExtendedModelConfig { + provider: 'custom'; + modelVersion: string; + customEndpoint: string; +} + +/** + * Batch processing configuration + */ +export interface BatchConfig { + /** Maximum number of concurrent requests */ + concurrency?: number; + /** Delay between batch requests */ + batchDelay?: number; + /** Retry failed requests */ + retryFailed?: boolean; +} + +/** + * Model capabilities interface + */ +export interface ModelCapabilities { + /** Supports text generation */ + textGeneration: boolean; + /** Supports chat/conversation */ + chat: boolean; + /** Supports streaming responses */ + streaming: boolean; + /** Supports batch processing */ + batchProcessing: boolean; + /** Maximum context length */ + maxContextLength?: number; + /** Supported languages */ + languages?: string[]; +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index dd8a1e4..1318fc5 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -10,6 +10,12 @@ export { AIService } from './ai/service'; export { CreditScorer } from './ai/credit-scorer'; export { FraudDetector } from './ai/fraud-detector'; +// AI Model base class and providers +export { AIModel, AIModelError } from './ai/base-model'; +export { OpenAIModel } from './ai/providers/openai-model'; +export { HuggingFaceModel } from './ai/providers/huggingface-model'; +export { CustomModel } from './ai/providers/custom-model'; + // Form validation utilities export { ValidationRules, validateField, validateFields } from './utils/validation'; export * from './types/form';