diff --git a/docs/tech-specs/ontology.md b/docs/tech-specs/ontology.md new file mode 100644 index 00000000..61cc09e4 --- /dev/null +++ b/docs/tech-specs/ontology.md @@ -0,0 +1,286 @@ +# Ontology Structure Technical Specification + +## Overview + +This specification describes the structure and format of ontologies within the TrustGraph system. Ontologies provide formal knowledge models that define classes, properties, and relationships, supporting reasoning and inference capabilities. The system uses an OWL-inspired configuration format that broadly represents OWL/RDFS concepts while being optimized for TrustGraph's requirements. + +**Naming Convention**: This project uses kebab-case for all identifiers (configuration keys, API endpoints, module names, etc.) rather than snake_case. + +## Goals + +- **Class and Property Management**: Define OWL-like classes with properties, domains, ranges, and type constraints +- **Rich Semantic Support**: Enable comprehensive RDFS/OWL properties including labels, multi-language support, and formal constraints +- **Multi-Ontology Support**: Allow multiple ontologies to coexist and interoperate +- **Validation and Reasoning**: Ensure ontologies conform to OWL-like standards with consistency checking and inference support +- **Standard Compatibility**: Support import/export in standard formats (Turtle, RDF/XML, OWL/XML) while maintaining internal optimization + +## Background + +TrustGraph stores ontologies as configuration items in a flexible key-value system. While the format is inspired by OWL (Web Ontology Language), it is optimized for TrustGraph's specific use cases and does not strictly adhere to all OWL specifications. + +Ontologies in TrustGraph enable: +- Definition of formal object types and their properties +- Specification of property domains, ranges, and type constraints +- Logical reasoning and inference +- Complex relationships and cardinality constraints +- Multi-language support for internationalization + +## Ontology Structure + +### Configuration Storage + +Ontologies are stored as configuration items with the following pattern: +- **Type**: `ontology` +- **Key**: Unique ontology identifier (e.g., `natural-world`, `domain-model`) +- **Value**: Complete ontology in JSON format + +### JSON Structure + +The ontology JSON format consists of four main sections: + +#### 1. Metadata + +Contains administrative and descriptive information about the ontology: + +```json +{ + "metadata": { + "name": "The natural world", + "description": "Ontology covering the natural order", + "version": "1.0.0", + "created": "2025-09-20T12:07:37.068Z", + "modified": "2025-09-20T12:12:20.725Z", + "creator": "current-user", + "namespace": "http://trustgraph.ai/ontologies/natural-world", + "imports": ["http://www.w3.org/2002/07/owl#"] + } +} +``` + +**Fields:** +- `name`: Human-readable name of the ontology +- `description`: Brief description of the ontology's purpose +- `version`: Semantic version number +- `created`: ISO 8601 timestamp of creation +- `modified`: ISO 8601 timestamp of last modification +- `creator`: Identifier of the creating user/system +- `namespace`: Base URI for ontology elements +- `imports`: Array of imported ontology URIs + +#### 2. 
Classes + +Defines the object types and their hierarchical relationships: + +```json +{ + "classes": { + "animal": { + "uri": "http://trustgraph.ai/ontologies/natural-world#animal", + "type": "owl:Class", + "rdfs:label": [{"value": "Animal", "lang": "en"}], + "rdfs:comment": "An animal", + "rdfs:subClassOf": "lifeform", + "owl:equivalentClass": ["creature"], + "owl:disjointWith": ["plant"], + "dcterms:identifier": "ANI-001" + } + } +} +``` + +**Supported Properties:** +- `uri`: Full URI of the class +- `type`: Always `"owl:Class"` +- `rdfs:label`: Array of language-tagged labels +- `rdfs:comment`: Description of the class +- `rdfs:subClassOf`: Parent class identifier (single inheritance) +- `owl:equivalentClass`: Array of equivalent class identifiers +- `owl:disjointWith`: Array of disjoint class identifiers +- `dcterms:identifier`: Optional external reference identifier + +#### 3. Object Properties + +Properties that link instances to other instances: + +```json +{ + "objectProperties": { + "has-parent": { + "uri": "http://trustgraph.ai/ontologies/natural-world#has-parent", + "type": "owl:ObjectProperty", + "rdfs:label": [{"value": "has parent", "lang": "en"}], + "rdfs:comment": "Links an animal to its parent", + "rdfs:domain": "animal", + "rdfs:range": "animal", + "owl:inverseOf": "parent-of", + "owl:functionalProperty": false + } + } +} +``` + +**Supported Properties:** +- `uri`: Full URI of the property +- `type`: Always `"owl:ObjectProperty"` +- `rdfs:label`: Array of language-tagged labels +- `rdfs:comment`: Description of the property +- `rdfs:domain`: Class identifier that has this property +- `rdfs:range`: Class identifier for property values +- `owl:inverseOf`: Identifier of inverse property +- `owl:functionalProperty`: Boolean indicating at most one value +- `owl:inverseFunctionalProperty`: Boolean for unique identifying properties + +#### 4. 
Datatype Properties + +Properties that link instances to literal values: + +```json +{ + "datatypeProperties": { + "number-of-legs": { + "uri": "http://trustgraph.ai/ontologies/natural-world#number-of-legs", + "type": "owl:DatatypeProperty", + "rdfs:label": [{"value": "number of legs", "lang": "en"}], + "rdfs:comment": "Count of number of legs of the animal", + "rdfs:domain": "animal", + "rdfs:range": "xsd:nonNegativeInteger", + "owl:functionalProperty": true, + "owl:minCardinality": 0, + "owl:maxCardinality": 1 + } + } +} +``` + +**Supported Properties:** +- `uri`: Full URI of the property +- `type`: Always `"owl:DatatypeProperty"` +- `rdfs:label`: Array of language-tagged labels +- `rdfs:comment`: Description of the property +- `rdfs:domain`: Class identifier that has this property +- `rdfs:range`: XSD datatype for property values +- `owl:functionalProperty`: Boolean indicating at most one value +- `owl:minCardinality`: Minimum number of values (optional) +- `owl:maxCardinality`: Maximum number of values (optional) +- `owl:cardinality`: Exact number of values (optional) + +### Supported XSD Datatypes + +The following XML Schema datatypes are supported for datatype property ranges: + +- `xsd:string` - Text values +- `xsd:integer` - Integer numbers +- `xsd:nonNegativeInteger` - Non-negative integers +- `xsd:float` - Floating point numbers +- `xsd:double` - Double precision numbers +- `xsd:boolean` - True/false values +- `xsd:dateTime` - Date and time values +- `xsd:date` - Date values +- `xsd:anyURI` - URI references + +### Language Support + +Labels and comments support multiple languages using the W3C language tag format: + +```json +{ + "rdfs:label": [ + {"value": "Animal", "lang": "en"}, + {"value": "Tier", "lang": "de"}, + {"value": "Animal", "lang": "es"} + ] +} +``` + +## Example Ontology + +Here's a complete example of a simple ontology: + +```json +{ + "metadata": { + "name": "The natural world", + "description": "Ontology covering the natural order", + "version": "1.0.0", + "created": "2025-09-20T12:07:37.068Z", + "modified": "2025-09-20T12:12:20.725Z", + "creator": "current-user", + "namespace": "http://trustgraph.ai/ontologies/natural-world", + "imports": ["http://www.w3.org/2002/07/owl#"] + }, + "classes": { + "lifeform": { + "uri": "http://trustgraph.ai/ontologies/natural-world#lifeform", + "type": "owl:Class", + "rdfs:label": [{"value": "Lifeform", "lang": "en"}], + "rdfs:comment": "A living thing" + }, + "animal": { + "uri": "http://trustgraph.ai/ontologies/natural-world#animal", + "type": "owl:Class", + "rdfs:label": [{"value": "Animal", "lang": "en"}], + "rdfs:comment": "An animal", + "rdfs:subClassOf": "lifeform" + }, + "cat": { + "uri": "http://trustgraph.ai/ontologies/natural-world#cat", + "type": "owl:Class", + "rdfs:label": [{"value": "Cat", "lang": "en"}], + "rdfs:comment": "A cat", + "rdfs:subClassOf": "animal" + }, + "dog": { + "uri": "http://trustgraph.ai/ontologies/natural-world#dog", + "type": "owl:Class", + "rdfs:label": [{"value": "Dog", "lang": "en"}], + "rdfs:comment": "A dog", + "rdfs:subClassOf": "animal", + "owl:disjointWith": ["cat"] + } + }, + "objectProperties": {}, + "datatypeProperties": { + "number-of-legs": { + "uri": "http://trustgraph.ai/ontologies/natural-world#number-of-legs", + "type": "owl:DatatypeProperty", + "rdfs:label": [{"value": "number-of-legs", "lang": "en"}], + "rdfs:comment": "Count of number of legs of the animal", + "rdfs:range": "xsd:nonNegativeInteger", + "rdfs:domain": "animal" + } + } +} +``` + +## Validation Rules + +### 
Structural Validation + +1. **URI Consistency**: All URIs should follow the pattern `{namespace}#{identifier}` +2. **Class Hierarchy**: No circular inheritance in `rdfs:subClassOf` +3. **Property Domains/Ranges**: Must reference existing classes or valid XSD types +4. **Disjoint Classes**: Cannot be subclasses of each other +5. **Inverse Properties**: Must be bidirectional if specified + +### Semantic Validation + +1. **Unique Identifiers**: Class and property identifiers must be unique within an ontology +2. **Language Tags**: Must follow BCP 47 language tag format +3. **Cardinality Constraints**: `minCardinality` ≤ `maxCardinality` when both specified +4. **Functional Properties**: Cannot have `maxCardinality` > 1 + +## Import/Export Format Support + +While the internal format is JSON, the system supports conversion to/from standard ontology formats: + +- **Turtle (.ttl)** - Compact RDF serialization +- **RDF/XML (.rdf, .owl)** - W3C standard format +- **OWL/XML (.owx)** - OWL-specific XML format +- **JSON-LD (.jsonld)** - JSON for Linked Data + +## References + +- [OWL 2 Web Ontology Language](https://www.w3.org/TR/owl2-overview/) +- [RDF Schema 1.1](https://www.w3.org/TR/rdf-schema/) +- [XML Schema Datatypes](https://www.w3.org/TR/xmlschema-2/) +- [BCP 47 Language Tags](https://tools.ietf.org/html/bcp47) \ No newline at end of file diff --git a/docs/tech-specs/ontorag.md b/docs/tech-specs/ontorag.md new file mode 100644 index 00000000..ae815e35 --- /dev/null +++ b/docs/tech-specs/ontorag.md @@ -0,0 +1,1443 @@ +# OntoRAG: Ontology-Based Knowledge Extraction and Query Technical Specification + +## Overview + +OntoRAG is an ontology-driven knowledge extraction and query system that enforces strict semantic consistency during both the extraction of knowledge triples from unstructured text and the querying of the resulting knowledge graph. Similar to GraphRAG but with formal ontology constraints, OntoRAG ensures all extracted triples conform to predefined ontological structures and provides semantically-aware querying capabilities. + +The system uses vector similarity matching to dynamically select relevant ontology subsets for both extraction and query operations, enabling focused and contextually appropriate processing while maintaining semantic validity. + +**Service Name**: `kg-extract-ontology` + +## Goals + +- **Ontology-Conformant Extraction**: Ensure all extracted triples strictly conform to loaded ontologies +- **Dynamic Context Selection**: Use embeddings to select relevant ontology subsets for each chunk +- **Semantic Consistency**: Maintain class hierarchies, property domains/ranges, and constraints +- **Efficient Processing**: Use in-memory vector stores for fast ontology element matching +- **Scalable Architecture**: Support multiple concurrent ontologies with different domains + +## Background + +Current knowledge extraction services (`kg-extract-definitions`, `kg-extract-relationships`) operate without formal constraints, potentially producing inconsistent or incompatible triples. OntoRAG addresses this by: + +1. Loading formal ontologies that define valid classes and properties +2. Using embeddings to match text content with relevant ontology elements +3. Constraining extraction to only produce ontology-conformant triples +4. Providing semantic validation of extracted knowledge + +This approach combines the flexibility of neural extraction with the rigor of formal knowledge representation. 
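To make the conformance requirement concrete, the sketch below uses the `natural-world` example ontology from the ontology structure specification; the instance name `buddy` and the rejected property are purely illustrative.

```python
# Illustration only, using the natural-world example ontology (see ontology.md).
# OntoRAG keeps triples that use defined classes/properties and satisfy their
# domain and range constraints, and drops everything else.

accepted = [
    ("buddy", "rdf:type", "dog"),          # "dog" is a defined class
    ("buddy", "number-of-legs", 4),        # domain: animal (dog is a subclass); range: xsd:nonNegativeInteger
]

rejected = [
    ("buddy", "has-altitude", 4),          # "has-altitude" is not a property in the ontology
    ("buddy", "number-of-legs", "many"),   # "many" is not a valid xsd:nonNegativeInteger literal
]
```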
+ +## Technical Design + +### Architecture + +The OntoRAG system consists of the following components: + +``` +┌─────────────────┐ +│ Configuration │ +│ Service │ +└────────┬────────┘ + │ Ontologies + ▼ +┌─────────────────┐ ┌──────────────┐ +│ kg-extract- │────▶│ Embedding │ +│ ontology │ │ Service │ +└────────┬────────┘ └──────────────┘ + │ │ + ▼ ▼ +┌─────────────────┐ ┌──────────────┐ +│ In-Memory │◀────│ Ontology │ +│ Vector Store │ │ Embedder │ +└────────┬────────┘ └──────────────┘ + │ + ▼ +┌─────────────────┐ ┌──────────────┐ +│ Sentence │────▶│ Chunker │ +│ Splitter │ │ Service │ +└────────┬────────┘ └──────────────┘ + │ + ▼ +┌─────────────────┐ ┌──────────────┐ +│ Ontology │────▶│ Vector │ +│ Selector │ │ Search │ +└────────┬────────┘ └──────────────┘ + │ + ▼ +┌─────────────────┐ ┌──────────────┐ +│ Prompt │────▶│ Prompt │ +│ Constructor │ │ Service │ +└────────┬────────┘ └──────────────┘ + │ + ▼ +┌─────────────────┐ +│ Triple Output │ +└─────────────────┘ +``` + +### Component Details + +#### 1. Ontology Loader + +**Purpose**: Retrieves and parses ontology configurations from the configuration service at service startup. + +**Algorithm Description**: +The Ontology Loader connects to the configuration service and requests all configuration items of type "ontology". For each ontology configuration found, it parses the JSON structure containing metadata, classes, object properties, and datatype properties. These parsed ontologies are stored in memory as structured objects that can be efficiently accessed during the extraction process. The loader runs once during service initialisation and can optionally refresh ontologies at configured intervals to pick up updates. + +**Key Operations**: +- Query configuration service for all ontology-type configurations +- Parse JSON ontology structures into internal object models +- Validate ontology structure and consistency +- Cache parsed ontologies in memory for fast access + +Loads ontologies from the configuration service during initialisation: + +```python +class OntologyLoader: + def __init__(self, config_service): + self.config_service = config_service + self.ontologies = {} + + async def load_ontologies(self): + # Fetch all ontology configurations + configs = await self.config_service.get_configs(type="ontology") + + for config_id, ontology_data in configs: + self.ontologies[config_id] = Ontology( + metadata=ontology_data['metadata'], + classes=ontology_data['classes'], + object_properties=ontology_data['objectProperties'], + datatype_properties=ontology_data['datatypeProperties'] + ) + + return self.ontologies +``` + +#### 2. Ontology Embedder + +**Purpose**: Creates vector embeddings for all ontology elements to enable semantic similarity matching. + +**Algorithm Description**: +The Ontology Embedder processes each element in the loaded ontologies (classes, object properties, and datatype properties) and generates vector embeddings using an embedding service. For each element, it combines the element's identifier with its description (from rdfs:comment) to create a text representation. This text is then converted to a high-dimensional vector embedding that captures its semantic meaning. These embeddings are stored in an in-memory vector store along with metadata about the element type, source ontology, and full definition. This preprocessing step happens once at startup, creating a searchable index of all ontology concepts. 
+ +**Key Operations**: +- Concatenate element IDs with their descriptions for rich semantic representation +- Generate embeddings via external embedding service (e.g., text-embedding-3-small) +- Store embeddings with comprehensive metadata in vector store +- Index by ontology, element type, and element ID for efficient retrieval + +Generates embeddings for ontology elements and stores them in an in-memory vector store: + +```python +class OntologyEmbedder: + def __init__(self, embedding_service, vector_store): + self.embedding_service = embedding_service + self.vector_store = vector_store + + async def embed_ontologies(self, ontologies): + for onto_id, ontology in ontologies.items(): + # Embed classes + for class_id, class_def in ontology.classes.items(): + text = f"{class_id} {class_def.get('rdfs:comment', '')}" + embedding = await self.embedding_service.embed(text) + + self.vector_store.add( + id=f"{onto_id}:class:{class_id}", + embedding=embedding, + metadata={ + 'type': 'class', + 'ontology': onto_id, + 'element': class_id, + 'definition': class_def + } + ) + + # Embed properties (object and datatype) + for prop_type in ['objectProperties', 'datatypeProperties']: + for prop_id, prop_def in getattr(ontology, prop_type).items(): + text = f"{prop_id} {prop_def.get('rdfs:comment', '')}" + embedding = await self.embedding_service.embed(text) + + self.vector_store.add( + id=f"{onto_id}:{prop_type}:{prop_id}", + embedding=embedding, + metadata={ + 'type': prop_type, + 'ontology': onto_id, + 'element': prop_id, + 'definition': prop_def + } + ) +``` + +#### 3. Sentence Splitter + +**Purpose**: Decomposes text chunks into fine-grained segments for precise ontology matching. + +**Algorithm Description**: +The Sentence Splitter takes incoming text chunks and breaks them down into smaller, more manageable units. First, it uses natural language processing techniques (via NLTK or spaCy) to identify sentence boundaries, handling edge cases like abbreviations and decimal points. Then, for each sentence, it extracts meaningful phrases including noun phrases (e.g., "the red car"), verb phrases (e.g., "quickly ran"), and named entities. This multi-level segmentation ensures that both complete thoughts (sentences) and specific concepts (phrases) can be matched against ontology elements. Each segment is tagged with its type and position information to maintain context. + +**Key Operations**: +- Split text into sentences using NLP sentence detection +- Extract noun phrases and verb phrases from each sentence +- Identify named entities and key terms +- Maintain hierarchical relationship between sentences and their phrases +- Preserve positional information for context reconstruction + +Breaks incoming chunks into smaller sentences and phrases for granular matching: + +```python +class SentenceSplitter: + def __init__(self): + # Use NLTK or spaCy for sophisticated splitting + self.sentence_detector = SentenceDetector() + self.phrase_extractor = PhraseExtractor() + + def split_chunk(self, chunk_text): + sentences = self.sentence_detector.split(chunk_text) + + segments = [] + for sentence in sentences: + # Add full sentence + segments.append({ + 'text': sentence, + 'type': 'sentence', + 'position': len(segments) + }) + + # Extract noun phrases and verb phrases + phrases = self.phrase_extractor.extract(sentence) + for phrase in phrases: + segments.append({ + 'text': phrase, + 'type': 'phrase', + 'parent_sentence': sentence, + 'position': len(segments) + }) + + return segments +``` + +#### 4. 
Ontology Selector + +**Purpose**: Identifies the most relevant subset of ontology elements for the current text chunk. + +**Algorithm Description**: +The Ontology Selector performs semantic matching between text segments and ontology elements using vector similarity search. For each sentence and phrase from the text chunk, it generates an embedding and searches the vector store for the most similar ontology elements. The search uses cosine similarity with a configurable threshold (e.g., 0.7) to find semantically related concepts. After collecting all relevant elements, it performs dependency resolution to ensure completeness - if a class is selected, its parent classes are included; if a property is selected, its domain and range classes are added. This creates a minimal but complete ontology subset that contains all necessary elements for valid triple extraction while avoiding irrelevant concepts that could confuse the extraction process. + +**Key Operations**: +- Generate embeddings for each text segment (sentences and phrases) +- Perform k-nearest neighbor search in the vector store +- Apply similarity threshold to filter weak matches +- Resolve dependencies (parent classes, domains, ranges) +- Construct coherent ontology subset with all required relationships +- Deduplicate elements appearing multiple times + +Uses vector similarity to find relevant ontology elements for each text segment: + +```python +class OntologySelector: + def __init__(self, embedding_service, vector_store): + self.embedding_service = embedding_service + self.vector_store = vector_store + + async def select_ontology_subset(self, segments, top_k=10): + relevant_elements = set() + + for segment in segments: + # Get embedding for segment + embedding = await self.embedding_service.embed(segment['text']) + + # Search for similar ontology elements + results = self.vector_store.search( + embedding=embedding, + top_k=top_k, + threshold=0.7 # Similarity threshold + ) + + for result in results: + relevant_elements.add(( + result['metadata']['ontology'], + result['metadata']['type'], + result['metadata']['element'], + result['metadata']['definition'] + )) + + # Build ontology subset + return self._build_subset(relevant_elements) + + def _build_subset(self, elements): + # Include selected elements and their dependencies + # (parent classes, domain/range references, etc.) + subset = { + 'classes': {}, + 'objectProperties': {}, + 'datatypeProperties': {} + } + + for onto_id, elem_type, elem_id, definition in elements: + if elem_type == 'class': + subset['classes'][elem_id] = definition + # Include parent classes + if 'rdfs:subClassOf' in definition: + parent = definition['rdfs:subClassOf'] + # Recursively add parent from full ontology + elif elem_type == 'objectProperties': + subset['objectProperties'][elem_id] = definition + # Include domain and range classes + elif elem_type == 'datatypeProperties': + subset['datatypeProperties'][elem_id] = definition + + return subset +``` + +#### 5. Prompt Constructor + +**Purpose**: Creates structured prompts that guide the LLM to extract only ontology-conformant triples. + +**Algorithm Description**: +The Prompt Constructor assembles a carefully formatted prompt that constrains the LLM's extraction to the selected ontology subset. It takes the relevant classes and properties identified by the Ontology Selector and formats them into clear instructions. Classes are presented with their hierarchical relationships and descriptions. 
Properties are shown with their domain and range constraints, making explicit what types of entities they can connect. The prompt includes strict rules about using only the provided ontology elements and respecting all constraints. The original text chunk is then appended, and the LLM is instructed to extract triples in the format (subject, predicate, object). This structured approach ensures the LLM understands both what to look for and what constraints to respect. + +**Key Operations**: +- Format classes with parent relationships and descriptions +- Format properties with domain/range constraints +- Include explicit extraction rules and constraints +- Specify output format for consistent parsing +- Balance prompt size with completeness of ontology information + +Builds prompts for the extraction service with ontology constraints: + +```python +class PromptConstructor: + def __init__(self): + self.template = """ +Extract knowledge triples from the following text using ONLY the provided ontology elements. + +ONTOLOGY CLASSES: +{classes} + +OBJECT PROPERTIES (connect entities): +{object_properties} + +DATATYPE PROPERTIES (entity attributes): +{datatype_properties} + +RULES: +1. Only use classes defined above for entity types +2. Only use properties defined above for relationships and attributes +3. Respect domain and range constraints +4. Output format: (subject, predicate, object) + +TEXT: +{text} + +TRIPLES: +""" + + def build_prompt(self, chunk_text, ontology_subset): + classes_str = self._format_classes(ontology_subset['classes']) + obj_props_str = self._format_properties( + ontology_subset['objectProperties'], + 'object' + ) + dt_props_str = self._format_properties( + ontology_subset['datatypeProperties'], + 'datatype' + ) + + return self.template.format( + classes=classes_str, + object_properties=obj_props_str, + datatype_properties=dt_props_str, + text=chunk_text + ) + + def _format_classes(self, classes): + lines = [] + for class_id, definition in classes.items(): + comment = definition.get('rdfs:comment', '') + parent = definition.get('rdfs:subClassOf', 'Thing') + lines.append(f"- {class_id} (subclass of {parent}): {comment}") + return '\n'.join(lines) + + def _format_properties(self, properties, prop_type): + lines = [] + for prop_id, definition in properties.items(): + comment = definition.get('rdfs:comment', '') + domain = definition.get('rdfs:domain', 'Any') + range_val = definition.get('rdfs:range', 'Any') + lines.append(f"- {prop_id} ({domain} -> {range_val}): {comment}") + return '\n'.join(lines) +``` + +#### 6. Main Extractor Service + +**Purpose**: Coordinates all components to perform end-to-end ontology-based triple extraction. + +**Algorithm Description**: +The Main Extractor Service is the orchestration layer that manages the complete extraction workflow. During initialisation, it loads all ontologies and pre-computes their embeddings, creating the searchable vector index. When a text chunk arrives for processing, it coordinates the pipeline: first splitting the text into segments, then finding relevant ontology elements through vector search, constructing a constrained prompt, calling the LLM service, and finally parsing and validating the response. The service ensures that each extracted triple conforms to the ontology by validating that subjects and objects are valid class instances, predicates are valid properties, and all domain/range constraints are satisfied. Only validated triples that fully conform to the ontology are returned. + +**Extraction Pipeline**: +1. 
Receive text chunk for processing +2. Split into sentences and phrases for granular analysis +3. Search vector store to find relevant ontology concepts +4. Build ontology subset including dependencies +5. Construct prompt with ontology constraints and text +6. Call LLM service for triple extraction +7. Parse response into structured triples +8. Validate each triple against ontology rules +9. Return only valid, ontology-conformant triples + +Orchestrates the complete extraction pipeline: + +```python +class KgExtractOntology: + def __init__(self, config): + self.loader = OntologyLoader(config['config_service']) + self.embedder = OntologyEmbedder( + config['embedding_service'], + InMemoryVectorStore() + ) + self.splitter = SentenceSplitter() + self.selector = OntologySelector( + config['embedding_service'], + self.embedder.vector_store + ) + self.prompt_builder = PromptConstructor() + self.prompt_service = config['prompt_service'] + + async def initialize(self): + # Load and embed ontologies at startup + ontologies = await self.loader.load_ontologies() + await self.embedder.embed_ontologies(ontologies) + + async def extract(self, chunk): + # Split chunk into segments + segments = self.splitter.split_chunk(chunk['text']) + + # Select relevant ontology subset + ontology_subset = await self.selector.select_ontology_subset(segments) + + # Build extraction prompt + prompt = self.prompt_builder.build_prompt( + chunk['text'], + ontology_subset + ) + + # Call prompt service + response = await self.prompt_service.generate(prompt) + + # Parse and validate triples + triples = self.parse_triples(response) + validated_triples = self.validate_triples(triples, ontology_subset) + + return validated_triples + + def parse_triples(self, response): + # Parse LLM response into structured triples + triples = [] + for line in response.split('\n'): + if line.strip().startswith('(') and line.strip().endswith(')'): + # Parse (subject, predicate, object) + parts = line.strip()[1:-1].split(',') + if len(parts) == 3: + triples.append({ + 'subject': parts[0].strip(), + 'predicate': parts[1].strip(), + 'object': parts[2].strip() + }) + return triples + + def validate_triples(self, triples, ontology_subset): + # Validate against ontology constraints + validated = [] + for triple in triples: + if self._is_valid(triple, ontology_subset): + validated.append(triple) + return validated +``` + +### Configuration + +The service loads configuration on startup: + +```yaml +kg-extract-ontology: + embedding_model: "text-embedding-3-small" + vector_store: + type: "in-memory" + similarity_threshold: 0.7 + top_k: 10 + sentence_splitter: + model: "nltk" + max_sentence_length: 512 + prompt_service: + endpoint: "http://prompt-service:8080" + model: "gpt-4" + temperature: 0.1 + ontology_refresh_interval: 300 # seconds +``` + +### Data Flow + +1. **Initialisation Phase**: + - Load ontologies from configuration service + - Generate embeddings for all ontology elements + - Store embeddings in in-memory vector store + +2. **Extraction Phase** (per chunk): + - Split chunk into sentences and phrases + - Compute embeddings for each segment + - Search vector store for relevant ontology elements + - Build ontology subset with selected elements + - Construct prompt with chunk text and ontology subset + - Call prompt service for extraction + - Parse and validate returned triples + - Output conformant triples + +### In-Memory Vector Store + +**Purpose**: Provides fast, memory-based similarity search for ontology element matching. 
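Whichever implementation is chosen, the extractor components above assume only a small `add`/`search` interface. A usage sketch with placeholder embeddings is shown here (the metadata shape mirrors the Ontology Embedder; the recommended FAISS implementation follows below):

```python
import numpy as np

store = FAISSVectorStore(dimension=1536)   # implementation given below

# Initialisation: index one ontology class together with its metadata.
store.add(
    id="natural-world:class:animal",
    embedding=np.random.rand(1536),        # stand-in for a real embedding vector
    metadata={
        "type": "class",
        "ontology": "natural-world",
        "element": "animal",
        "definition": {"rdfs:comment": "An animal"},
    },
)

# Extraction: look up ontology elements similar to a text segment's embedding.
for match in store.search(np.random.rand(1536), top_k=5, threshold=0.7):
    print(match["id"], round(match["score"], 3), match["metadata"]["element"])
```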
+ +**Recommended Implementation: FAISS** + +The system should use **FAISS (Facebook AI Similarity Search)** as the primary vector store implementation for the following reasons: + +1. **Performance**: Optimised for similarity search with microsecond latency, critical for real-time query processing +2. **Memory Efficiency**: Multiple index types (Flat, IVF, HNSW) allow memory/speed tradeoffs based on ontology size +3. **Scalability**: Efficiently handles hundreds to tens of thousands of ontology elements +4. **Production Ready**: Battle-tested in production environments with excellent stability +5. **Python Integration**: Native Python bindings with numpy compatibility for seamless integration + +**FAISS Implementation**: + +```python +import faiss +import numpy as np + +class FAISSVectorStore: + def __init__(self, dimension=1536, index_type='flat'): + """ + Initialize FAISS vector store. + + Args: + dimension: Embedding dimension (1536 for text-embedding-3-small) + index_type: 'flat' for exact search, 'ivf' for larger datasets + """ + self.dimension = dimension + self.metadata = [] + self.ids = [] + + if index_type == 'flat': + # Exact search - best for ontologies with <10k elements + self.index = faiss.IndexFlatIP(dimension) + else: + # Approximate search - for larger ontologies + quantizer = faiss.IndexFlatIP(dimension) + self.index = faiss.IndexIVFFlat(quantizer, dimension, 100) + self.index.train(np.random.randn(1000, dimension).astype('float32')) + + def add(self, id, embedding, metadata): + """Add single embedding with metadata.""" + # Normalize for cosine similarity + embedding = embedding / np.linalg.norm(embedding) + self.index.add(np.array([embedding], dtype=np.float32)) + self.metadata.append(metadata) + self.ids.append(id) + + def add_batch(self, ids, embeddings, metadata_list): + """Batch add for initial ontology loading.""" + # Normalize all embeddings + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + normalized = embeddings / norms + self.index.add(normalized.astype(np.float32)) + self.metadata.extend(metadata_list) + self.ids.extend(ids) + + def search(self, embedding, top_k=10, threshold=0.0): + """Search for similar vectors.""" + # Normalize query + embedding = embedding / np.linalg.norm(embedding) + + # Search + scores, indices = self.index.search( + np.array([embedding], dtype=np.float32), + min(top_k, self.index.ntotal) + ) + + # Filter by threshold and format results + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx >= 0 and score >= threshold: # FAISS returns -1 for empty slots + results.append({ + 'id': self.ids[idx], + 'score': float(score), + 'metadata': self.metadata[idx] + }) + + return results + + def clear(self): + """Reset the store.""" + self.index.reset() + self.metadata = [] + self.ids = [] + + def size(self): + """Return number of stored vectors.""" + return self.index.ntotal +``` + +**Fallback Implementation (NumPy)**: + +For development or small-scale deployments, a simple NumPy implementation can be used: + +```python +class SimpleVectorStore: + """Fallback implementation using NumPy - suitable for <1000 elements.""" + def __init__(self): + self.embeddings = [] + self.metadata = [] + self.ids = [] + + def add(self, id, embedding, metadata): + self.embeddings.append(embedding / np.linalg.norm(embedding)) + self.metadata.append(metadata) + self.ids.append(id) + + def search(self, embedding, top_k=10, threshold=0.0): + if not self.embeddings: + return [] + + # Normalize and compute similarities + embedding = embedding / 
np.linalg.norm(embedding) + similarities = np.dot(self.embeddings, embedding) + + # Get top-k indices + top_indices = np.argsort(similarities)[::-1][:top_k] + + # Build results + results = [] + for idx in top_indices: + if similarities[idx] >= threshold: + results.append({ + 'id': self.ids[idx], + 'score': float(similarities[idx]), + 'metadata': self.metadata[idx] + }) + + return results +``` + +### Ontology Subset Selection Algorithm + +**Purpose**: Dynamically selects the minimal relevant portion of the ontology for each text chunk. + +**Detailed Algorithm Steps**: + +1. **Text Segmentation**: + - Split the input chunk into sentences using NLP sentence detection + - Extract noun phrases, verb phrases, and named entities from each sentence + - Create a hierarchical structure of segments preserving context + +2. **Embedding Generation**: + - Generate vector embeddings for each text segment (sentences and phrases) + - Use the same embedding model as used for ontology elements + - Cache embeddings for repeated segments to improve performance + +3. **Similarity Search**: + - For each text segment embedding, search the vector store + - Retrieve top-k (e.g., 10) most similar ontology elements + - Apply similarity threshold (e.g., 0.7) to filter weak matches + - Aggregate results across all segments, tracking match frequencies + +4. **Dependency Resolution**: + - For each selected class, recursively include all parent classes up to root + - For each selected property, include its domain and range classes + - For inverse properties, ensure both directions are included + - Add equivalent classes if they exist in the ontology + +5. **Subset Construction**: + - Deduplicate collected elements while preserving relationships + - Organise into classes, object properties, and datatype properties + - Ensure all constraints and relationships are preserved + - Create a self-contained mini-ontology that is valid and complete + +**Example Walkthrough**: +Given text: "The brown dog chased the white cat up the tree." +- Segments: ["brown dog", "white cat", "tree", "chased"] +- Matched elements: [dog (class), cat (class), animal (parent), chases (property)] +- Dependencies: [animal (parent of dog and cat), lifeform (parent of animal)] +- Final subset: Complete mini-ontology with animal hierarchy and chase relationship + +### Triple Validation + +**Purpose**: Ensures all extracted triples strictly conform to ontology constraints. + +**Validation Algorithm**: + +1. **Class Validation**: + - Verify that subjects are instances of classes defined in the ontology subset + - For object properties, verify that objects are also valid class instances + - Check class names against the ontology's class dictionary + - Handle class hierarchies - instances of subclasses are valid for parent class constraints + +2. **Property Validation**: + - Confirm predicates correspond to properties in the ontology subset + - Distinguish between object properties (entity-to-entity) and datatype properties (entity-to-literal) + - Verify property names match exactly (considering namespace if present) + +3. **Domain/Range Checking**: + - For each property used as predicate, retrieve its domain and range + - Verify the subject's type matches or inherits from the property's domain + - Verify the object's type matches or inherits from the property's range + - For datatype properties, verify the object is a literal of the correct XSD type + +4. 
**Cardinality Validation**: + - Track property usage counts per subject + - Check minimum cardinality - ensure required properties are present + - Check maximum cardinality - ensure property isn't used too many times + - For functional properties, ensure at most one value per subject + +5. **Datatype Validation**: + - Parse literal values according to their declared XSD types + - Validate integers are valid numbers, dates are properly formatted, etc. + - Check string patterns if regex constraints are defined + - Ensure URIs are well-formed for xsd:anyURI types + +**Validation Example**: +Triple: ("Buddy", "has-owner", "John") +- Check "Buddy" is typed as a class that can have "has-owner" property +- Check "has-owner" exists in the ontology +- Verify domain constraint: subject must be of type "Pet" or subclass +- Verify range constraint: object must be of type "Person" or subclass +- If valid, add to output; if invalid, log violation and skip + +## Performance Considerations + +### Optimisation Strategies + +1. **Embedding Caching**: Cache embeddings for frequently used text segments +2. **Batch Processing**: Process multiple segments in parallel +3. **Vector Store Indexing**: Use approximate nearest neighbor algorithms for large ontologies +4. **Prompt Optimisation**: Minimise prompt size by including only essential ontology elements +5. **Result Caching**: Cache extraction results for identical chunks + +### Scalability + +- **Horizontal Scaling**: Multiple extractor instances with shared ontology cache +- **Ontology Partitioning**: Split large ontologies by domain +- **Streaming Processing**: Process chunks as they arrive without batching +- **Memory Management**: Periodic cleanup of unused embeddings + +## Error Handling + +### Failure Scenarios + +1. **Missing Ontologies**: Fallback to unconstrained extraction +2. **Embedding Service Failure**: Use cached embeddings or skip semantic matching +3. **Prompt Service Timeout**: Retry with exponential backoff +4. **Invalid Triple Format**: Log and skip malformed triples +5. **Ontology Inconsistencies**: Report conflicts and use most specific valid elements + +### Monitoring + +Key metrics to track: + +- Ontology load time and memory usage +- Embedding generation latency +- Vector search performance +- Prompt service response time +- Triple extraction accuracy +- Ontology conformance rate + +## Migration Path + +### From Existing Extractors + +1. **Parallel Operation**: Run alongside existing extractors initially +2. **Gradual Rollout**: Start with specific document types +3. **Quality Comparison**: Compare output quality with existing extractors +4. **Full Migration**: Replace existing extractors once quality verified + +### Ontology Development + +1. **Bootstrap from Existing**: Generate initial ontologies from existing knowledge +2. **Iterative Refinement**: Refine based on extraction patterns +3. **Domain Expert Review**: Validate with subject matter experts +4. **Continuous Improvement**: Update based on extraction feedback + +## Ontology-Sensitive Query Service + +### Overview + +The ontology-sensitive query service provides multiple query paths to support different backend graph stores. It leverages ontology knowledge for precise, semantically-aware question answering across both Cassandra (via SPARQL) and Cypher-based graph stores (Neo4j, Memgraph, FalkorDB). 
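In rough outline, the per-backend routing might look like the sketch below; the `graph-store` configuration key and the function shape are assumptions for illustration, while the component and backend names are those described in the remainder of this section.

```python
# Hedged sketch of routing between the SPARQL and Cypher query paths.
# The configuration key name ("graph-store") is illustrative, not the real setting.

CYPHER_BACKENDS = {"neo4j", "memgraph", "falkordb"}

def route_query(question: str, ontology_partition: dict, config: dict) -> tuple:
    backend = config.get("graph-store", "cassandra")
    if backend == "cassandra":
        # onto-query-sparql generates SPARQL, executed via sparql-cassandra
        return ("onto-query-sparql", question, ontology_partition)
    if backend in CYPHER_BACKENDS:
        # onto-query-cypher generates Cypher, executed via cypher-executor
        return ("onto-query-cypher", question, ontology_partition)
    raise ValueError(f"Unsupported graph backend: {backend}")
```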
+ +**Service Components**: +- `onto-query-sparql`: Converts natural language to SPARQL for Cassandra +- `sparql-cassandra`: SPARQL query layer for Cassandra using rdflib +- `onto-query-cypher`: Converts natural language to Cypher for graph databases +- `cypher-executor`: Cypher query execution for Neo4j/Memgraph/FalkorDB + +### Architecture + +``` + ┌─────────────────┐ + │ User Query │ + └────────┬────────┘ + │ + ▼ + ┌─────────────────┐ ┌──────────────┐ + │ Question │────▶│ Sentence │ + │ Analyser │ │ Splitter │ + └────────┬────────┘ └──────────────┘ + │ + ▼ + ┌─────────────────┐ ┌──────────────┐ + │ Ontology │────▶│ Vector │ + │ Matcher │ │ Store │ + └────────┬────────┘ └──────────────┘ + │ + ▼ + ┌─────────────────┐ + │ Backend Router │ + └────────┬────────┘ + │ + ┌───────────┴───────────┐ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ onto-query- │ │ onto-query- │ + │ sparql │ │ cypher │ + └────────┬────────┘ └────────┬────────┘ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ SPARQL │ │ Cypher │ + │ Generator │ │ Generator │ + └────────┬────────┘ └────────┬────────┘ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ sparql- │ │ cypher- │ + │ cassandra │ │ executor │ + └────────┬────────┘ └────────┬────────┘ + │ │ + ▼ ▼ + ┌─────────────────┐ ┌─────────────────┐ + │ Cassandra │ │ Neo4j/Memgraph/ │ + │ │ │ FalkorDB │ + └────────┬────────┘ └────────┬────────┘ + │ │ + └────────────┬───────────────┘ + │ + ▼ + ┌─────────────────┐ ┌──────────────┐ + │ Answer │────▶│ Prompt │ + │ Generator │ │ Service │ + └────────┬────────┘ └──────────────┘ + │ + ▼ + ┌─────────────────┐ + │ Final Answer │ + └─────────────────┘ +``` + +### Query Processing Pipeline + +#### 1. Question Analyser + +**Purpose**: Decomposes user questions into semantic components for ontology matching. + +**Algorithm Description**: +The Question Analyser takes the incoming natural language question and breaks it down into meaningful segments using the same sentence splitting approach as the extraction pipeline. It identifies key entities, relationships, and constraints mentioned in the question. Each segment is analysed for question type (factual, aggregation, comparison, etc.) and the expected answer format. This decomposition helps identify which parts of the ontology are most relevant for answering the question. + +**Key Operations**: +- Split question into sentences and phrases +- Identify question type and intent +- Extract mentioned entities and relationships +- Detect constraints and filters in the question +- Determine expected answer format + +#### 2. Ontology Matcher for Queries + +**Purpose**: Identifies the relevant ontology subset needed to answer the question. + +**Algorithm Description**: +Similar to the extraction pipeline's Ontology Selector, but optimised for question answering. The matcher generates embeddings for question segments and searches the vector store for relevant ontology elements. However, it focuses on finding concepts that would be useful for query construction rather than extraction. It expands the selection to include related properties that might be traversed during graph exploration, even if not explicitly mentioned in the question. For example, if asked about "employees," it might include properties like "works-for," "manages," and "reports-to" that could be relevant for finding employee information. 
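A rough sketch of this query-focused expansion is given below, assuming the `Ontology` objects loaded by the `OntologyLoader` expose `classes`, `object_properties` and `datatype_properties`; the expansion heuristic itself is illustrative rather than the actual implementation.

```python
def expand_for_query(matched_classes: set, ontology) -> dict:
    """Build a query-focused ontology partition from directly matched classes."""
    partition = {
        "classes": {c: ontology.classes[c] for c in matched_classes if c in ontology.classes},
        "objectProperties": {},
        "datatypeProperties": {},
    }

    # Pull in object properties that connect matched classes, e.g. "works-for"
    # or "reports-to" when the question only mentions "employee".
    for prop_id, prop_def in ontology.object_properties.items():
        if prop_def.get("rdfs:domain") in matched_classes or \
           prop_def.get("rdfs:range") in matched_classes:
            partition["objectProperties"][prop_id] = prop_def

    # Keep datatype properties of matched classes so the generated query can
    # project or filter on literal values.
    for prop_id, prop_def in ontology.datatype_properties.items():
        if prop_def.get("rdfs:domain") in matched_classes:
            partition["datatypeProperties"][prop_id] = prop_def

    return partition
```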
+ +**Matching Strategy**: +- Embed question segments +- Find directly mentioned ontology concepts +- Include properties that connect mentioned classes +- Add inverse and related properties for traversal +- Include parent/child classes for hierarchical queries +- Build query-focused ontology partition + +#### 3. Backend Router + +**Purpose**: Routes queries to the appropriate backend-specific query path based on configuration. + +**Algorithm Description**: +The Backend Router examines the system configuration to determine which graph backend is active (Cassandra or Cypher-based). It routes the question and ontology partition to the appropriate query generation service. The router can also support load balancing across multiple backends or fallback mechanisms if the primary backend is unavailable. + +**Routing Logic**: +- Check configured backend type from system settings +- Route to `onto-query-sparql` for Cassandra backends +- Route to `onto-query-cypher` for Neo4j/Memgraph/FalkorDB +- Support multi-backend configurations with query distribution +- Handle failover and load balancing scenarios + +#### 4. SPARQL Query Generation (`onto-query-sparql`) + +**Purpose**: Converts natural language questions to SPARQL queries for Cassandra execution. + +**Algorithm Description**: +The SPARQL query generator takes the question and ontology partition and constructs a SPARQL query optimised for execution against the Cassandra backend. It uses the prompt service with a SPARQL-specific template that includes RDF/OWL semantics. The generator understands SPARQL patterns like property paths, optional clauses, and filters that can efficiently translate to Cassandra operations. + +**SPARQL Generation Prompt Template**: +``` +Generate a SPARQL query for the following question using the provided ontology. + +ONTOLOGY CLASSES: +{classes} + +ONTOLOGY PROPERTIES: +{properties} + +RULES: +- Use proper RDF/OWL semantics +- Include relevant prefixes +- Use property paths for hierarchical queries +- Add FILTER clauses for constraints +- Optimise for Cassandra backend + +QUESTION: {question} + +SPARQL QUERY: +``` + +#### 5. Cypher Query Generation (`onto-query-cypher`) + +**Purpose**: Converts natural language questions to Cypher queries for graph databases. + +**Algorithm Description**: +The Cypher query generator creates native Cypher queries optimised for Neo4j, Memgraph, and FalkorDB. It maps ontology classes to node labels and properties to relationships, using Cypher's pattern matching syntax. The generator includes Cypher-specific optimisations like relationship direction hints, index usage, and query planning hints. + +**Cypher Generation Prompt Template**: +``` +Generate a Cypher query for the following question using the provided ontology. + +NODE LABELS (from classes): +{classes} + +RELATIONSHIP TYPES (from properties): +{properties} + +RULES: +- Use MATCH patterns for graph traversal +- Include WHERE clauses for filters +- Use aggregation functions when needed +- Optimise for graph database performance +- Consider index hints for large datasets + +QUESTION: {question} + +CYPHER QUERY: +``` + +#### 6. SPARQL-Cassandra Query Engine (`sparql-cassandra`) + +**Purpose**: Executes SPARQL queries against Cassandra using Python rdflib. + +**Algorithm Description**: +The SPARQL-Cassandra engine implements a SPARQL processor using Python's rdflib library with a custom Cassandra backend store. It translates SPARQL graph patterns into appropriate Cassandra CQL queries, handling joins, filters, and aggregations. 
The engine maintains an RDF-to-Cassandra mapping that preserves the semantic structure while optimising for Cassandra's column-family storage model. + +**Implementation Features**: +- rdflib Store interface implementation for Cassandra +- SPARQL 1.1 query support with common patterns +- Efficient translation of triple patterns to CQL +- Support for property paths and hierarchical queries +- Result streaming for large datasets +- Connection pooling and query caching + +**Example Translation**: +```sparql +SELECT ?animal WHERE { + ?animal rdf:type :Animal . + ?animal :hasOwner "John" . +} +``` +Translates to optimised Cassandra queries leveraging indexes and partition keys. + +#### 7. Cypher Query Executor (`cypher-executor`) + +**Purpose**: Executes Cypher queries against Neo4j, Memgraph, and FalkorDB. + +**Algorithm Description**: +The Cypher executor provides a unified interface for executing Cypher queries across different graph databases. It handles database-specific connection protocols, query optimisation hints, and result format normalisation. The executor includes retry logic, connection pooling, and transaction management appropriate for each database type. + +**Multi-Database Support**: +- **Neo4j**: Bolt protocol, transaction functions, index hints +- **Memgraph**: Custom protocol, streaming results, analytical queries +- **FalkorDB**: Redis protocol adaptation, in-memory optimisations + +**Execution Features**: +- Database-agnostic connection management +- Query validation and syntax checking +- Timeout and resource limit enforcement +- Result pagination and streaming +- Performance monitoring per database type +- Automatic failover between database instances + +#### 8. Answer Generator + +**Purpose**: Synthesises a natural language answer from query results. + +**Algorithm Description**: +The Answer Generator takes the structured query results and the original question, then uses the prompt service to generate a comprehensive answer. Unlike simple template-based responses, it uses an LLM to interpret the graph data in the context of the question, handling complex relationships, aggregations, and inferences. The generator can explain its reasoning by referencing the ontology structure and the specific triples retrieved from the graph. + +**Answer Generation Process**: +- Format query results into structured context +- Include relevant ontology definitions for clarity +- Construct prompt with question and results +- Generate natural language answer via LLM +- Validate answer against query intent +- Add citations to specific graph entities if needed + +### Integration with Existing Services + +#### Relationship with GraphRAG + +- **Complementary**: onto-query provides semantic precision while GraphRAG provides broad coverage +- **Shared Infrastructure**: Both use the same knowledge graph and prompt services +- **Query Routing**: System can route queries to most appropriate service based on question type +- **Hybrid Mode**: Can combine both approaches for comprehensive answers + +#### Relationship with OntoRAG Extraction + +- **Shared Ontologies**: Uses same ontology configurations loaded by kg-extract-ontology +- **Shared Vector Store**: Reuses the in-memory embeddings from extraction service +- **Consistent Semantics**: Queries operate on graphs built with same ontological constraints + +### Query Examples + +#### Example 1: Simple Entity Query +**Question**: "What animals are mammals?" 
+**Ontology Match**: [animal, mammal, subClassOf] +**Generated Query**: +```cypher +MATCH (a:animal)-[:subClassOf*]->(m:mammal) +RETURN a.name +``` + +#### Example 2: Relationship Query +**Question**: "Which documents were authored by John Smith?" +**Ontology Match**: [document, person, has-author] +**Generated Query**: +```cypher +MATCH (d:document)-[:has-author]->(p:person {name: "John Smith"}) +RETURN d.title, d.date +``` + +#### Example 3: Aggregation Query +**Question**: "How many legs do cats have?" +**Ontology Match**: [cat, number-of-legs (datatype property)] +**Generated Query**: +```cypher +MATCH (c:cat) +RETURN c.name, c.number_of_legs +``` + +### Configuration + +```yaml +onto-query: + embedding_model: "text-embedding-3-small" + vector_store: + shared_with_extractor: true # Reuse kg-extract-ontology's store + query_builder: + model: "gpt-4" + temperature: 0.1 + max_query_length: 1000 + graph_executor: + timeout: 30000 # ms + max_results: 1000 + answer_generator: + model: "gpt-4" + temperature: 0.3 + max_tokens: 500 +``` + +### Performance Optimisations + +#### Query Optimisation + +- **Ontology Pruning**: Only include necessary ontology elements in prompts +- **Query Caching**: Cache frequently asked questions and their queries +- **Result Caching**: Store results for identical queries within time window +- **Batch Processing**: Handle multiple related questions in single graph traversal + +#### Scalability Considerations + +- **Distributed Execution**: Parallelise subqueries across graph partitions +- **Incremental Results**: Stream results for large datasets +- **Load Balancing**: Distribute query load across multiple service instances +- **Resource Pools**: Manage connection pools to graph databases + +### Error Handling + +#### Failure Scenarios + +1. **Invalid Query Generation**: Fallback to GraphRAG or simple keyword search +2. **Ontology Mismatch**: Expand search to broader ontology subset +3. **Query Timeout**: Simplify query or increase timeout +4. **Empty Results**: Suggest query reformulation or related questions +5. **LLM Service Failure**: Use cached queries or template-based responses + +### Monitoring Metrics + +- Question complexity distribution +- Ontology partition sizes +- Query generation success rate +- Graph query execution time +- Answer quality scores +- Cache hit rates +- Error frequencies by type + +## Future Enhancements + +1. **Ontology Learning**: Automatically extend ontologies based on extraction patterns +2. **Confidence Scoring**: Assign confidence scores to extracted triples +3. **Explanation Generation**: Provide reasoning for triple extraction +4. **Active Learning**: Request human validation for uncertain extractions + +## Security Considerations + +1. **Prompt Injection Prevention**: Sanitise chunk text before prompt construction +2. **Resource Limits**: Cap memory usage for vector store +3. **Rate Limiting**: Limit extraction requests per client +4. 
**Audit Logging**: Track all extraction requests and results + +## Testing Strategy + +### Unit Testing + +- Ontology loader with various formats +- Embedding generation and storage +- Sentence splitting algorithms +- Vector similarity calculations +- Triple parsing and validation + +### Integration Testing + +- End-to-end extraction pipeline +- Configuration service integration +- Prompt service interaction +- Concurrent extraction handling + +### Performance Testing + +- Large ontology handling (1000+ classes) +- High-volume chunk processing +- Memory usage under load +- Latency benchmarks + +## Delivery Plan + +### Overview + +The OntoRAG system will be delivered in four major phases, with each phase providing incremental value while building toward the complete system. The plan focuses on establishing core extraction capabilities first, then adding query functionality, followed by optimizations and advanced features. + +### Phase 1: Foundation and Core Extraction + +**Goal**: Establish the basic ontology-driven extraction pipeline with simple vector matching. + +#### Step 1.1: Ontology Management Foundation +- Implement ontology configuration loader (`OntologyLoader`) +- Parse and validate ontology JSON structures +- Create in-memory ontology storage and access patterns +- Implement ontology refresh mechanism + +**Success Criteria**: +- Successfully load and parse ontology configurations +- Validate ontology structure and consistency +- Handle multiple concurrent ontologies + +#### Step 1.2: Vector Store Implementation +- Implement simple NumPy-based vector store as initial prototype +- Add FAISS vector store implementation +- Create vector store interface abstraction +- Implement similarity search with configurable thresholds + +**Success Criteria**: +- Store and retrieve embeddings efficiently +- Perform similarity search with <100ms latency +- Support both NumPy and FAISS backends + +#### Step 1.3: Ontology Embedding Pipeline +- Integrate with embedding service +- Implement `OntologyEmbedder` component +- Generate embeddings for all ontology elements +- Store embeddings with metadata in vector store + +**Success Criteria**: +- Generate embeddings for classes and properties +- Store embeddings with proper metadata +- Rebuild embeddings on ontology updates + +#### Step 1.4: Text Processing Components +- Implement sentence splitter using NLTK/spaCy +- Extract phrases and named entities +- Create text segment hierarchy +- Generate embeddings for text segments + +**Success Criteria**: +- Accurately split text into sentences +- Extract meaningful phrases +- Maintain context relationships + +#### Step 1.5: Ontology Selection Algorithm +- Implement similarity matching between text and ontology +- Build dependency resolution for ontology elements +- Create minimal coherent ontology subsets +- Optimize subset generation performance + +**Success Criteria**: +- Select relevant ontology elements with >80% precision +- Include all necessary dependencies +- Generate subsets in <500ms + +#### Step 1.6: Basic Extraction Service +- Implement prompt construction for extraction +- Integrate with prompt service +- Parse and validate triple responses +- Create `kg-extract-ontology` service endpoint + +**Success Criteria**: +- Extract ontology-conformant triples +- Validate all triples against ontology +- Handle extraction errors gracefully + +### Phase 2: Query System Implementation + +**Goal**: Add ontology-aware query capabilities with support for multiple backends. 
+ +#### Step 2.1: Query Foundation Components +- Implement question analyzer +- Create ontology matcher for queries +- Adapt vector search for query context +- Build backend router component + +**Success Criteria**: +- Analyze questions into semantic components +- Match questions to relevant ontology elements +- Route queries to appropriate backend + +#### Step 2.2: SPARQL Path Implementation +- Implement `onto-query-sparql` service +- Create SPARQL query generator using LLM +- Develop prompt templates for SPARQL generation +- Validate generated SPARQL syntax + +**Success Criteria**: +- Generate valid SPARQL queries +- Use appropriate SPARQL patterns +- Handle complex query types + +#### Step 2.3: SPARQL-Cassandra Engine +- Implement rdflib Store interface for Cassandra +- Create CQL query translator +- Optimize triple pattern matching +- Handle SPARQL result formatting + +**Success Criteria**: +- Execute SPARQL queries on Cassandra +- Support common SPARQL patterns +- Return results in standard format + +#### Step 2.4: Cypher Path Implementation +- Implement `onto-query-cypher` service +- Create Cypher query generator using LLM +- Develop prompt templates for Cypher generation +- Validate generated Cypher syntax + +**Success Criteria**: +- Generate valid Cypher queries +- Use appropriate graph patterns +- Support Neo4j, Memgraph, FalkorDB + +#### Step 2.5: Cypher Executor +- Implement multi-database Cypher executor +- Support Bolt protocol (Neo4j/Memgraph) +- Support Redis protocol (FalkorDB) +- Handle result normalization + +**Success Criteria**: +- Execute Cypher on all target databases +- Handle database-specific differences +- Maintain connection pools efficiently + +#### Step 2.6: Answer Generation +- Implement answer generator component +- Create prompts for answer synthesis +- Format query results for LLM consumption +- Generate natural language answers + +**Success Criteria**: +- Generate accurate answers from query results +- Maintain context from original question +- Provide clear, concise responses + +### Phase 3: Optimization and Robustness + +**Goal**: Optimize performance, add caching, improve error handling, and enhance reliability. 
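+
+As a simple illustration of the kind of result caching and retry-with-backoff behaviour this phase targets, the sketch below shows a TTL-based cache and an exponential-backoff wrapper. The names (`TTLCache`, `with_backoff`) and default values are assumptions for illustration; the real implementation will be driven by the configuration and error-handling steps that follow.
+
+```python
+# Illustrative sketch of query-result caching and retry with exponential backoff.
+import asyncio
+import time
+from typing import Any, Awaitable, Callable, Dict, Tuple
+
+
+class TTLCache:
+    """Cache results for identical queries within a time window."""
+
+    def __init__(self, ttl_seconds: float = 60.0):
+        self.ttl = ttl_seconds
+        self._store: Dict[str, Tuple[float, Any]] = {}
+
+    def get(self, key: str):
+        entry = self._store.get(key)
+        if entry is None:
+            return None
+        ts, value = entry
+        if time.monotonic() - ts > self.ttl:
+            del self._store[key]    # expired entry
+            return None
+        return value
+
+    def put(self, key: str, value: Any):
+        self._store[key] = (time.monotonic(), value)
+
+
+async def with_backoff(fn: Callable[[], Awaitable[Any]], retries: int = 3,
+                       base_delay: float = 0.5) -> Any:
+    """Retry an async call, doubling the delay after each failure."""
+    for attempt in range(retries):
+        try:
+            return await fn()
+        except Exception:
+            if attempt == retries - 1:
+                raise
+            await asyncio.sleep(base_delay * (2 ** attempt))
+```
+
+In the actual services, the cache key could be derived from the normalized question or the generated graph query, and the backoff wrapper would likely sit around calls to the graph executor and LLM services.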
+ +#### Step 3.1: Performance Optimization +- Implement embedding caching +- Add query result caching +- Optimize vector search with FAISS IVF indexes +- Implement batch processing for embeddings + +**Success Criteria**: +- Reduce average query latency by 50% +- Support 10x more concurrent requests +- Maintain sub-second response times + +#### Step 3.2: Advanced Error Handling +- Implement comprehensive error recovery +- Add fallback mechanisms between query paths +- Create retry logic with exponential backoff +- Improve error logging and diagnostics + +**Success Criteria**: +- Gracefully handle all failure scenarios +- Automatic failover between backends +- Detailed error reporting for debugging + +#### Step 3.3: Monitoring and Observability +- Add performance metrics collection +- Implement query tracing +- Create health check endpoints +- Add resource usage monitoring + +**Success Criteria**: +- Track all key performance indicators +- Identify bottlenecks quickly +- Monitor system health in real-time + +#### Step 3.4: Configuration Management +- Implement dynamic configuration updates +- Add configuration validation +- Create configuration templates +- Support environment-specific settings + +**Success Criteria**: +- Update configuration without restart +- Validate all configuration changes +- Support multiple deployment environments + +### Phase 4: Advanced Features + +**Goal**: Add sophisticated capabilities for production deployment and enhanced functionality. + +#### Step 4.1: Multi-Ontology Support +- Implement ontology selection logic +- Support cross-ontology queries +- Handle ontology versioning +- Create ontology merge capabilities + +**Success Criteria**: +- Query across multiple ontologies +- Handle ontology conflicts +- Support ontology evolution + +#### Step 4.2: Intelligent Query Routing +- Implement performance-based routing +- Add query complexity analysis +- Create adaptive routing algorithms +- Support A/B testing for paths + +**Success Criteria**: +- Route queries optimally +- Learn from query performance +- Improve routing over time + +#### Step 4.3: Advanced Extraction Features +- Add confidence scoring for triples +- Implement explanation generation +- Create feedback loops for improvement +- Support incremental learning + +**Success Criteria**: +- Provide confidence scores +- Explain extraction decisions +- Continuously improve accuracy + +#### Step 4.4: Production Hardening +- Add rate limiting +- Implement authentication/authorization +- Create deployment automation +- Add backup and recovery + +**Success Criteria**: +- Production-ready security +- Automated deployment pipeline +- Disaster recovery capability + +### Delivery Milestones + +1. **Milestone 1** (End of Phase 1): Basic ontology-driven extraction operational +2. **Milestone 2** (End of Phase 2): Full query system with both SPARQL and Cypher paths +3. **Milestone 3** (End of Phase 3): Optimized, robust system ready for staging +4. 
**Milestone 4** (End of Phase 4): Production-ready system with advanced features + +### Risk Mitigation + +#### Technical Risks +- **Vector Store Scalability**: Start with NumPy, migrate to FAISS gradually +- **Query Generation Accuracy**: Implement validation and fallback mechanisms +- **Backend Compatibility**: Test extensively with each database type +- **Performance Bottlenecks**: Profile early and often, optimize iteratively + +#### Operational Risks +- **Ontology Quality**: Implement validation and consistency checking +- **Service Dependencies**: Add circuit breakers and fallbacks +- **Resource Constraints**: Monitor and set appropriate limits +- **Data Consistency**: Implement proper transaction handling + +### Success Metrics + +#### Phase 1 Success Metrics +- Extraction accuracy: >90% ontology conformance +- Processing speed: <1 second per chunk +- Ontology load time: <10 seconds +- Vector search latency: <100ms + +#### Phase 2 Success Metrics +- Query success rate: >95% +- Query latency: <2 seconds end-to-end +- Backend compatibility: 100% for target databases +- Answer accuracy: >85% based on available data + +#### Phase 3 Success Metrics +- System uptime: >99.9% +- Error recovery rate: >95% +- Cache hit rate: >60% +- Concurrent users: >100 + +#### Phase 4 Success Metrics +- Multi-ontology queries: Fully supported +- Routing optimization: 30% latency reduction +- Confidence scoring accuracy: >90% +- Production deployment: Zero-downtime updates + +## References + +- [OWL 2 Web Ontology Language](https://www.w3.org/TR/owl2-overview/) +- [GraphRAG Architecture](https://github.com/microsoft/graphrag) +- [Sentence Transformers](https://www.sbert.net/) +- [FAISS Vector Search](https://github.com/facebookresearch/faiss) +- [spaCy NLP Library](https://spacy.io/) +- [rdflib Documentation](https://rdflib.readthedocs.io/) +- [Neo4j Bolt Protocol](https://neo4j.com/docs/bolt/current/) diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/__init__.py b/trustgraph-flow/trustgraph/extract/kg/ontology/__init__.py new file mode 100644 index 00000000..102255a1 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/__init__.py @@ -0,0 +1 @@ +from . extract import * \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py new file mode 100644 index 00000000..b0942dc2 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -0,0 +1,473 @@ +""" +OntoRAG: Ontology-based knowledge extraction service. +Extracts ontology-conformant triples from text chunks. +""" + +import json +import logging +import asyncio +from typing import List, Dict, Any, Optional + +from .... schema import Chunk, Triple, Triples, Metadata, Value +from .... schema import PromptRequest, PromptResponse +from .... rdf import TRUSTGRAPH_ENTITIES, RDF_TYPE, RDF_LABEL +from .... base import FlowProcessor, ConsumerSpec, ProducerSpec +from .... base import PromptClientSpec +from .... 
tables.config import ConfigTableStore + +from .ontology_loader import OntologyLoader +from .ontology_embedder import OntologyEmbedder +from .vector_store import InMemoryVectorStore +from .text_processor import TextProcessor +from .ontology_selector import OntologySelector, OntologySubset + +logger = logging.getLogger(__name__) + +default_ident = "kg-extract-ontology" +default_concurrency = 1 + + +class Processor(FlowProcessor): + """Main OntoRAG extraction processor.""" + + def __init__(self, **params): + id = params.get("id", default_ident) + concurrency = params.get("concurrency", default_concurrency) + + super(Processor, self).__init__( + **params | { + "id": id, + "concurrency": concurrency, + } + ) + + # Register specifications + self.register_specification( + ConsumerSpec( + name="input", + schema=Chunk, + handler=self.on_message, + concurrency=concurrency, + ) + ) + + self.register_specification( + PromptClientSpec( + request_name="prompt-request", + response_name="prompt-response", + ) + ) + + self.register_specification( + ProducerSpec( + name="triples", + schema=Triples + ) + ) + + # Initialize components + self.ontology_loader = None + self.ontology_embedder = None + self.text_processor = TextProcessor() + self.ontology_selector = None + self.initialized = False + + # Configuration + self.top_k = params.get("top_k", 10) + self.similarity_threshold = params.get("similarity_threshold", 0.7) + self.refresh_interval = params.get("ontology_refresh_interval", 300) + + # Cassandra configuration for config store + self.cassandra_host = params.get("cassandra_host", "localhost") + self.cassandra_username = params.get("cassandra_username", "cassandra") + self.cassandra_password = params.get("cassandra_password", "cassandra") + self.cassandra_keyspace = params.get("cassandra_keyspace", "trustgraph") + + async def initialize_components(self, flow): + """Initialize OntoRAG components.""" + if self.initialized: + return + + try: + # Create configuration store + config_store = ConfigTableStore( + self.cassandra_host, + self.cassandra_username, + self.cassandra_password, + self.cassandra_keyspace + ) + + # Initialize ontology loader + self.ontology_loader = OntologyLoader(config_store) + ontologies = await self.ontology_loader.load_ontologies() + logger.info(f"Loaded {len(ontologies)} ontologies") + + # Initialize vector store + vector_store = InMemoryVectorStore.create( + dimension=1536, # text-embedding-3-small + prefer_faiss=True, + index_type='flat' + ) + + # Initialize ontology embedder with embedding service wrapper + embedding_service = EmbeddingServiceWrapper(flow) + self.ontology_embedder = OntologyEmbedder( + embedding_service=embedding_service, + vector_store=vector_store + ) + + # Embed all ontologies + if ontologies: + await self.ontology_embedder.embed_ontologies(ontologies) + logger.info(f"Embedded {self.ontology_embedder.get_embedded_count()} ontology elements") + + # Initialize ontology selector + self.ontology_selector = OntologySelector( + ontology_embedder=self.ontology_embedder, + ontology_loader=self.ontology_loader, + top_k=self.top_k, + similarity_threshold=self.similarity_threshold + ) + + self.initialized = True + logger.info("OntoRAG components initialized successfully") + + # Schedule periodic refresh + asyncio.create_task(self.refresh_ontologies_periodically()) + + except Exception as e: + logger.error(f"Failed to initialize OntoRAG components: {e}", exc_info=True) + raise + + async def refresh_ontologies_periodically(self): + """Periodically refresh ontologies from 
configuration.""" + while True: + await asyncio.sleep(self.refresh_interval) + try: + logger.info("Refreshing ontologies...") + ontologies = await self.ontology_loader.refresh_ontologies() + if ontologies: + # Re-embed new ontologies + for ont_id in ontologies: + if not self.ontology_embedder.is_ontology_embedded(ont_id): + await self.ontology_embedder.embed_ontology(ontologies[ont_id]) + logger.info("Ontology refresh complete") + except Exception as e: + logger.error(f"Error refreshing ontologies: {e}", exc_info=True) + + async def on_message(self, msg, consumer, flow): + """Process incoming chunk message.""" + v = msg.value() + logger.info(f"Extracting ontology-based triples from {v.metadata.id}...") + + # Initialize components if needed + if not self.initialized: + await self.initialize_components(flow) + + chunk = v.chunk.decode("utf-8") + logger.debug(f"Processing chunk: {chunk[:200]}...") + + try: + # Process text into segments + segments = self.text_processor.process_chunk(chunk, extract_phrases=True) + logger.debug(f"Split chunk into {len(segments)} segments") + + # Select relevant ontology subset + ontology_subsets = await self.ontology_selector.select_ontology_subset(segments) + + if not ontology_subsets: + logger.warning("No relevant ontology elements found for chunk") + # Emit empty triples + await self.emit_triples( + flow("triples"), + v.metadata, + [] + ) + return + + # Merge subsets if multiple ontologies matched + if len(ontology_subsets) > 1: + ontology_subset = self.ontology_selector.merge_subsets(ontology_subsets) + else: + ontology_subset = ontology_subsets[0] + + logger.debug(f"Selected ontology subset with {len(ontology_subset.classes)} classes, " + f"{len(ontology_subset.object_properties)} object properties, " + f"{len(ontology_subset.datatype_properties)} datatype properties") + + # Build extraction prompt + prompt = self.build_extraction_prompt(chunk, ontology_subset) + + # Call prompt service for extraction + try: + triples_response = await flow("prompt-request").extract_ontology_triples( + prompt=prompt + ) + logger.debug(f"Extraction response: {triples_response}") + + if not isinstance(triples_response, list): + logger.error("Expected list of triples from prompt service") + triples_response = [] + + except Exception as e: + logger.error(f"Prompt service error: {e}", exc_info=True) + triples_response = [] + + # Parse and validate triples + triples = self.parse_and_validate_triples(triples_response, ontology_subset) + + # Add metadata triples + for t in v.metadata.metadata: + triples.append(t) + + # Emit triples + await self.emit_triples( + flow("triples"), + v.metadata, + triples + ) + + logger.info(f"Extracted {len(triples)} ontology-conformant triples") + + except Exception as e: + logger.error(f"OntoRAG extraction exception: {e}", exc_info=True) + # Emit empty triples on error + await self.emit_triples( + flow("triples"), + v.metadata, + [] + ) + + def build_extraction_prompt(self, chunk: str, ontology_subset: OntologySubset) -> str: + """Build prompt for ontology-based extraction.""" + # Format classes + classes_str = self.format_classes(ontology_subset.classes) + + # Format properties + obj_props_str = self.format_properties( + ontology_subset.object_properties, + "object" + ) + dt_props_str = self.format_properties( + ontology_subset.datatype_properties, + "datatype" + ) + + prompt = f"""Extract knowledge triples from the following text using ONLY the provided ontology elements. 
+ +ONTOLOGY CLASSES: +{classes_str} + +OBJECT PROPERTIES (connect entities): +{obj_props_str} + +DATATYPE PROPERTIES (entity attributes): +{dt_props_str} + +RULES: +1. Only use classes defined above for entity types +2. Only use properties defined above for relationships and attributes +3. Respect domain and range constraints +4. Output format: JSON array of {{"subject": "", "predicate": "", "object": ""}} +5. For class instances, use rdf:type as predicate +6. Include rdfs:label for new entities + +TEXT: +{chunk} + +TRIPLES (JSON array):""" + + return prompt + + def format_classes(self, classes: Dict[str, Any]) -> str: + """Format classes for prompt.""" + if not classes: + return "None" + + lines = [] + for class_id, definition in classes.items(): + comment = definition.get('comment', '') + parent = definition.get('subclass_of', 'Thing') + lines.append(f"- {class_id} (subclass of {parent}): {comment}") + + return '\n'.join(lines) + + def format_properties(self, properties: Dict[str, Any], prop_type: str) -> str: + """Format properties for prompt.""" + if not properties: + return "None" + + lines = [] + for prop_id, definition in properties.items(): + comment = definition.get('comment', '') + domain = definition.get('domain', 'Any') + range_val = definition.get('range', 'Any') + lines.append(f"- {prop_id} ({domain} -> {range_val}): {comment}") + + return '\n'.join(lines) + + def parse_and_validate_triples(self, triples_response: List[Any], + ontology_subset: OntologySubset) -> List[Triple]: + """Parse and validate extracted triples against ontology.""" + validated_triples = [] + + for triple_data in triples_response: + try: + if isinstance(triple_data, dict): + subject = triple_data.get('subject', '') + predicate = triple_data.get('predicate', '') + object_val = triple_data.get('object', '') + + if not subject or not predicate or not object_val: + continue + + # Validate against ontology + if self.is_valid_triple(subject, predicate, object_val, ontology_subset): + # Create Triple object + s_value = Value(value=subject, is_uri=self.is_uri(subject)) + p_value = Value(value=predicate, is_uri=True) + o_value = Value(value=object_val, is_uri=self.is_uri(object_val)) + + validated_triples.append(Triple( + s=s_value, + p=p_value, + o=o_value + )) + else: + logger.debug(f"Invalid triple: ({subject}, {predicate}, {object_val})") + + except Exception as e: + logger.error(f"Error parsing triple: {e}") + + return validated_triples + + def is_valid_triple(self, subject: str, predicate: str, object_val: str, + ontology_subset: OntologySubset) -> bool: + """Validate triple against ontology constraints.""" + # Special case for rdf:type + if predicate == "rdf:type" or predicate == str(RDF_TYPE): + # Check if object is a valid class + return object_val in ontology_subset.classes + + # Special case for rdfs:label + if predicate == "rdfs:label" or predicate == str(RDF_LABEL): + return True # Labels are always valid + + # Check if predicate is a valid property + is_obj_prop = predicate in ontology_subset.object_properties + is_dt_prop = predicate in ontology_subset.datatype_properties + + if not is_obj_prop and not is_dt_prop: + return False # Unknown property + + # TODO: Add more sophisticated validation (domain/range checking) + return True + + def is_uri(self, value: str) -> bool: + """Check if value is a URI.""" + return value.startswith("http://") or value.startswith("https://") or \ + value.startswith(str(TRUSTGRAPH_ENTITIES)) or \ + value in ["rdf:type", "rdfs:label"] + + async def emit_triples(self, 
pub, metadata: Metadata, triples: List[Triple]): + """Emit triples to output.""" + t = Triples( + metadata=Metadata( + id=metadata.id, + metadata=[], + user=metadata.user, + collection=metadata.collection, + ), + triples=triples, + ) + await pub.send(t) + + @staticmethod + def add_args(parser): + """Add command-line arguments.""" + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + parser.add_argument( + '--top-k', + type=int, + default=10, + help='Number of top ontology elements to retrieve (default: 10)' + ) + parser.add_argument( + '--similarity-threshold', + type=float, + default=0.7, + help='Similarity threshold for ontology matching (default: 0.7)' + ) + parser.add_argument( + '--ontology-refresh-interval', + type=int, + default=300, + help='Ontology refresh interval in seconds (default: 300)' + ) + parser.add_argument( + '--cassandra-host', + type=str, + default='localhost', + help='Cassandra host (default: localhost)' + ) + parser.add_argument( + '--cassandra-username', + type=str, + default='cassandra', + help='Cassandra username (default: cassandra)' + ) + parser.add_argument( + '--cassandra-password', + type=str, + default='cassandra', + help='Cassandra password (default: cassandra)' + ) + parser.add_argument( + '--cassandra-keyspace', + type=str, + default='trustgraph', + help='Cassandra keyspace (default: trustgraph)' + ) + FlowProcessor.add_args(parser) + + +class EmbeddingServiceWrapper: + """Wrapper to adapt flow prompt service to embedding service interface.""" + + def __init__(self, flow): + self.flow = flow + + async def embed(self, text: str): + """Generate embedding for single text.""" + try: + response = await self.flow("prompt-request").get_embedding(text=text) + return response + except Exception as e: + logger.error(f"Embedding service error: {e}") + return None + + async def embed_batch(self, texts: List[str]): + """Generate embeddings for multiple texts.""" + try: + # Process in parallel for better performance + tasks = [self.embed(text) for text in texts] + embeddings = await asyncio.gather(*tasks) + # Filter out None values and convert to array + import numpy as np + valid_embeddings = [e for e in embeddings if e is not None] + if valid_embeddings: + return np.array(valid_embeddings) + return None + except Exception as e: + logger.error(f"Batch embedding service error: {e}") + return None + + +def run(): + """Launch the OntoRAG extraction service.""" + Processor.launch(default_ident, __doc__) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py new file mode 100644 index 00000000..402f3b7a --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py @@ -0,0 +1,276 @@ +""" +Ontology embedder component for OntoRAG system. +Generates and stores embeddings for ontology elements. 
+""" + +import logging +import numpy as np +from typing import Dict, List, Any, Optional +from dataclasses import dataclass + +from .ontology_loader import Ontology, OntologyClass, OntologyProperty +from .vector_store import VectorStore, InMemoryVectorStore + +logger = logging.getLogger(__name__) + + +@dataclass +class OntologyElementMetadata: + """Metadata for an embedded ontology element.""" + type: str # 'class', 'objectProperty', 'datatypeProperty' + ontology: str # Ontology ID + element: str # Element ID + definition: Dict[str, Any] # Full element definition + text: str # Text used for embedding + + +class OntologyEmbedder: + """Generates embeddings for ontology elements and stores them in vector store.""" + + def __init__(self, embedding_service=None, vector_store: Optional[VectorStore] = None): + """Initialize the ontology embedder. + + Args: + embedding_service: Service for generating embeddings + vector_store: Vector store instance (defaults to InMemoryVectorStore) + """ + self.embedding_service = embedding_service + self.vector_store = vector_store or InMemoryVectorStore.create() + self.embedded_ontologies = set() + + def _create_text_representation(self, element_id: str, element: Any, + element_type: str) -> str: + """Create text representation of an ontology element for embedding. + + Args: + element_id: ID of the element + element: The element object (OntologyClass or OntologyProperty) + element_type: Type of element + + Returns: + Text representation for embedding + """ + parts = [] + + # Add the element ID (often meaningful) + parts.append(element_id.replace('-', ' ').replace('_', ' ')) + + # Add labels + if hasattr(element, 'labels') and element.labels: + for label in element.labels: + if isinstance(label, dict): + parts.append(label.get('value', '')) + else: + parts.append(str(label)) + + # Add comment/description + if hasattr(element, 'comment') and element.comment: + parts.append(element.comment) + + # Add type-specific information + if element_type == 'class': + if hasattr(element, 'subclass_of') and element.subclass_of: + parts.append(f"subclass of {element.subclass_of}") + elif element_type in ['objectProperty', 'datatypeProperty']: + if hasattr(element, 'domain') and element.domain: + parts.append(f"domain: {element.domain}") + if hasattr(element, 'range') and element.range: + parts.append(f"range: {element.range}") + + # Join all parts with spaces + text = ' '.join(filter(None, parts)) + return text + + async def embed_ontology(self, ontology: Ontology) -> int: + """Generate and store embeddings for all elements in an ontology. 
+ + Args: + ontology: The ontology to embed + + Returns: + Number of elements embedded + """ + if not self.embedding_service: + logger.warning("No embedding service available, skipping embedding") + return 0 + + embedded_count = 0 + batch_size = 50 # Process embeddings in batches + + # Collect all elements to embed + elements_to_embed = [] + + # Process classes + for class_id, class_def in ontology.classes.items(): + text = self._create_text_representation(class_id, class_def, 'class') + elements_to_embed.append({ + 'id': f"{ontology.id}:class:{class_id}", + 'text': text, + 'metadata': OntologyElementMetadata( + type='class', + ontology=ontology.id, + element=class_id, + definition=class_def.__dict__, + text=text + ).__dict__ + }) + + # Process object properties + for prop_id, prop_def in ontology.object_properties.items(): + text = self._create_text_representation(prop_id, prop_def, 'objectProperty') + elements_to_embed.append({ + 'id': f"{ontology.id}:objectProperty:{prop_id}", + 'text': text, + 'metadata': OntologyElementMetadata( + type='objectProperty', + ontology=ontology.id, + element=prop_id, + definition=prop_def.__dict__, + text=text + ).__dict__ + }) + + # Process datatype properties + for prop_id, prop_def in ontology.datatype_properties.items(): + text = self._create_text_representation(prop_id, prop_def, 'datatypeProperty') + elements_to_embed.append({ + 'id': f"{ontology.id}:datatypeProperty:{prop_id}", + 'text': text, + 'metadata': OntologyElementMetadata( + type='datatypeProperty', + ontology=ontology.id, + element=prop_id, + definition=prop_def.__dict__, + text=text + ).__dict__ + }) + + # Process in batches + for i in range(0, len(elements_to_embed), batch_size): + batch = elements_to_embed[i:i + batch_size] + + # Get embeddings for batch + texts = [elem['text'] for elem in batch] + try: + # Call embedding service (async) + embeddings = await self.embedding_service.embed_batch(texts) + + # Store in vector store + ids = [elem['id'] for elem in batch] + metadata_list = [elem['metadata'] for elem in batch] + + self.vector_store.add_batch(ids, embeddings, metadata_list) + embedded_count += len(batch) + + logger.debug(f"Embedded batch of {len(batch)} elements from ontology {ontology.id}") + + except Exception as e: + logger.error(f"Failed to embed batch for ontology {ontology.id}: {e}") + + self.embedded_ontologies.add(ontology.id) + logger.info(f"Embedded {embedded_count} elements from ontology {ontology.id}") + return embedded_count + + async def embed_ontologies(self, ontologies: Dict[str, Ontology]) -> int: + """Generate and store embeddings for multiple ontologies. + + Args: + ontologies: Dictionary of ontology ID to Ontology objects + + Returns: + Total number of elements embedded + """ + total_embedded = 0 + + for ont_id, ontology in ontologies.items(): + if ont_id not in self.embedded_ontologies: + count = await self.embed_ontology(ontology) + total_embedded += count + else: + logger.debug(f"Ontology {ont_id} already embedded, skipping") + + logger.info(f"Total embedded elements: {total_embedded} from {len(ontologies)} ontologies") + return total_embedded + + async def embed_text(self, text: str) -> Optional[np.ndarray]: + """Generate embedding for a single text. 
+ + Args: + text: Text to embed + + Returns: + Embedding vector or None if failed + """ + if not self.embedding_service: + logger.warning("No embedding service available") + return None + + try: + embedding = await self.embedding_service.embed(text) + return embedding + except Exception as e: + logger.error(f"Failed to embed text: {e}") + return None + + async def embed_texts(self, texts: List[str]) -> Optional[np.ndarray]: + """Generate embeddings for multiple texts. + + Args: + texts: List of texts to embed + + Returns: + Array of embeddings or None if failed + """ + if not self.embedding_service: + logger.warning("No embedding service available") + return None + + try: + embeddings = await self.embedding_service.embed_batch(texts) + return embeddings + except Exception as e: + logger.error(f"Failed to embed texts: {e}") + return None + + def clear_embeddings(self, ontology_id: Optional[str] = None): + """Clear embeddings from vector store. + + Args: + ontology_id: If provided, only clear embeddings for this ontology + Otherwise, clear all embeddings + """ + if ontology_id: + # Would need to implement selective clearing in vector store + # For now, log warning + logger.warning(f"Selective clearing not implemented, would clear {ontology_id}") + else: + self.vector_store.clear() + self.embedded_ontologies.clear() + logger.info("Cleared all embeddings from vector store") + + def get_vector_store(self) -> VectorStore: + """Get the vector store instance. + + Returns: + The vector store being used + """ + return self.vector_store + + def get_embedded_count(self) -> int: + """Get the number of embedded elements. + + Returns: + Number of elements in the vector store + """ + return self.vector_store.size() + + def is_ontology_embedded(self, ontology_id: str) -> bool: + """Check if an ontology has been embedded. + + Args: + ontology_id: ID of the ontology + + Returns: + True if the ontology has been embedded + """ + return ontology_id in self.embedded_ontologies \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py new file mode 100644 index 00000000..2dc53003 --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_loader.py @@ -0,0 +1,262 @@ +""" +Ontology loader component for OntoRAG system. +Loads and manages ontologies from configuration service. 
+""" + +import json +import logging +from typing import Dict, Any, Optional, List +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass +class OntologyClass: + """Represents an OWL-like class in the ontology.""" + uri: str + type: str = "owl:Class" + labels: List[Dict[str, str]] = field(default_factory=list) + comment: Optional[str] = None + subclass_of: Optional[str] = None + equivalent_classes: List[str] = field(default_factory=list) + disjoint_with: List[str] = field(default_factory=list) + identifier: Optional[str] = None + + @staticmethod + def from_dict(class_id: str, data: Dict[str, Any]) -> 'OntologyClass': + """Create OntologyClass from dictionary representation.""" + labels = data.get('rdfs:label', []) + if isinstance(labels, list): + labels = labels + else: + labels = [labels] if labels else [] + + return OntologyClass( + uri=data.get('uri', ''), + type=data.get('type', 'owl:Class'), + labels=labels, + comment=data.get('rdfs:comment'), + subclass_of=data.get('rdfs:subClassOf'), + equivalent_classes=data.get('owl:equivalentClass', []), + disjoint_with=data.get('owl:disjointWith', []), + identifier=data.get('dcterms:identifier') + ) + + +@dataclass +class OntologyProperty: + """Represents a property (object or datatype) in the ontology.""" + uri: str + type: str + labels: List[Dict[str, str]] = field(default_factory=list) + comment: Optional[str] = None + domain: Optional[str] = None + range: Optional[str] = None + inverse_of: Optional[str] = None + functional: bool = False + inverse_functional: bool = False + min_cardinality: Optional[int] = None + max_cardinality: Optional[int] = None + cardinality: Optional[int] = None + + @staticmethod + def from_dict(prop_id: str, data: Dict[str, Any]) -> 'OntologyProperty': + """Create OntologyProperty from dictionary representation.""" + labels = data.get('rdfs:label', []) + if isinstance(labels, list): + labels = labels + else: + labels = [labels] if labels else [] + + return OntologyProperty( + uri=data.get('uri', ''), + type=data.get('type', ''), + labels=labels, + comment=data.get('rdfs:comment'), + domain=data.get('rdfs:domain'), + range=data.get('rdfs:range'), + inverse_of=data.get('owl:inverseOf'), + functional=data.get('owl:functionalProperty', False), + inverse_functional=data.get('owl:inverseFunctionalProperty', False), + min_cardinality=data.get('owl:minCardinality'), + max_cardinality=data.get('owl:maxCardinality'), + cardinality=data.get('owl:cardinality') + ) + + +@dataclass +class Ontology: + """Represents a complete ontology with metadata, classes, and properties.""" + id: str + metadata: Dict[str, Any] + classes: Dict[str, OntologyClass] + object_properties: Dict[str, OntologyProperty] + datatype_properties: Dict[str, OntologyProperty] + + def get_class(self, class_id: str) -> Optional[OntologyClass]: + """Get a class by ID.""" + return self.classes.get(class_id) + + def get_property(self, prop_id: str) -> Optional[OntologyProperty]: + """Get a property (object or datatype) by ID.""" + prop = self.object_properties.get(prop_id) + if prop is None: + prop = self.datatype_properties.get(prop_id) + return prop + + def get_parent_classes(self, class_id: str) -> List[str]: + """Get all parent classes (following subClassOf hierarchy).""" + parents = [] + current = class_id + visited = set() + + while current and current not in visited: + visited.add(current) + cls = self.get_class(current) + if cls and cls.subclass_of: + parents.append(cls.subclass_of) + current = cls.subclass_of + else: 
+ break + + return parents + + def validate_structure(self) -> List[str]: + """Validate ontology structure and return list of issues.""" + issues = [] + + # Check for circular inheritance + for class_id in self.classes: + visited = set() + current = class_id + while current: + if current in visited: + issues.append(f"Circular inheritance detected for class {class_id}") + break + visited.add(current) + cls = self.get_class(current) + if cls: + current = cls.subclass_of + else: + break + + # Check property domains and ranges exist + for prop_id, prop in {**self.object_properties, **self.datatype_properties}.items(): + if prop.domain and prop.domain not in self.classes: + issues.append(f"Property {prop_id} has unknown domain {prop.domain}") + if prop.type == "owl:ObjectProperty" and prop.range and prop.range not in self.classes: + issues.append(f"Object property {prop_id} has unknown range class {prop.range}") + + # Check disjoint classes + for class_id, cls in self.classes.items(): + for disjoint_id in cls.disjoint_with: + if disjoint_id not in self.classes: + issues.append(f"Class {class_id} disjoint with unknown class {disjoint_id}") + + return issues + + +class OntologyLoader: + """Loads and manages ontologies from configuration service.""" + + def __init__(self, config_store=None): + """Initialize the ontology loader. + + Args: + config_store: Configuration store instance (injected dependency) + """ + self.config_store = config_store + self.ontologies: Dict[str, Ontology] = {} + self.refresh_interval = 300 # Default 5 minutes + + async def load_ontologies(self) -> Dict[str, Ontology]: + """Load all ontologies from configuration service. + + Returns: + Dictionary of ontology ID to Ontology objects + """ + if not self.config_store: + logger.warning("No configuration store available, returning empty ontologies") + return {} + + try: + # Get all ontology configurations + ontology_configs = await self.config_store.get("ontology").values() + + for ont_id, ont_data in ontology_configs.items(): + try: + # Parse JSON if string + if isinstance(ont_data, str): + ont_data = json.loads(ont_data) + + # Parse classes + classes = {} + for class_id, class_data in ont_data.get('classes', {}).items(): + classes[class_id] = OntologyClass.from_dict(class_id, class_data) + + # Parse object properties + object_props = {} + for prop_id, prop_data in ont_data.get('objectProperties', {}).items(): + object_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data) + + # Parse datatype properties + datatype_props = {} + for prop_id, prop_data in ont_data.get('datatypeProperties', {}).items(): + datatype_props[prop_id] = OntologyProperty.from_dict(prop_id, prop_data) + + # Create ontology + ontology = Ontology( + id=ont_id, + metadata=ont_data.get('metadata', {}), + classes=classes, + object_properties=object_props, + datatype_properties=datatype_props + ) + + # Validate structure + issues = ontology.validate_structure() + if issues: + logger.warning(f"Ontology {ont_id} has validation issues: {issues}") + + self.ontologies[ont_id] = ontology + logger.info(f"Loaded ontology {ont_id} with {len(classes)} classes, " + f"{len(object_props)} object properties, " + f"{len(datatype_props)} datatype properties") + + except Exception as e: + logger.error(f"Failed to load ontology {ont_id}: {e}", exc_info=True) + + except Exception as e: + logger.error(f"Failed to load ontologies from config: {e}", exc_info=True) + + return self.ontologies + + async def refresh_ontologies(self): + """Refresh ontologies from configuration 
service.""" + logger.info("Refreshing ontologies...") + return await self.load_ontologies() + + def get_ontology(self, ont_id: str) -> Optional[Ontology]: + """Get a specific ontology by ID. + + Args: + ont_id: Ontology identifier + + Returns: + Ontology object or None if not found + """ + return self.ontologies.get(ont_id) + + def get_all_ontologies(self) -> Dict[str, Ontology]: + """Get all loaded ontologies. + + Returns: + Dictionary of ontology ID to Ontology objects + """ + return self.ontologies + + def clear(self): + """Clear all loaded ontologies.""" + self.ontologies.clear() + logger.info("Cleared all loaded ontologies") \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py new file mode 100644 index 00000000..4389f0bc --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_selector.py @@ -0,0 +1,297 @@ +""" +Ontology selection algorithm for OntoRAG system. +Selects relevant ontology subsets based on text similarity. +""" + +import logging +from typing import List, Dict, Any, Set, Optional, Tuple +from dataclasses import dataclass +from collections import defaultdict + +from .ontology_loader import Ontology, OntologyLoader +from .ontology_embedder import OntologyEmbedder +from .text_processor import TextSegment +from .vector_store import SearchResult + +logger = logging.getLogger(__name__) + + +@dataclass +class OntologySubset: + """Represents a subset of an ontology relevant to a text chunk.""" + ontology_id: str + classes: Dict[str, Any] + object_properties: Dict[str, Any] + datatype_properties: Dict[str, Any] + metadata: Dict[str, Any] + relevance_score: float = 0.0 + + +class OntologySelector: + """Selects relevant ontology elements for text segments using vector similarity.""" + + def __init__(self, ontology_embedder: OntologyEmbedder, + ontology_loader: OntologyLoader, + top_k: int = 10, + similarity_threshold: float = 0.7): + """Initialize the ontology selector. + + Args: + ontology_embedder: Embedder with vector store + ontology_loader: Loader with ontology definitions + top_k: Number of top results to retrieve per segment + similarity_threshold: Minimum similarity score + """ + self.embedder = ontology_embedder + self.loader = ontology_loader + self.top_k = top_k + self.similarity_threshold = similarity_threshold + + async def select_ontology_subset(self, segments: List[TextSegment]) -> List[OntologySubset]: + """Select relevant ontology subsets for text segments. + + Args: + segments: List of text segments to match + + Returns: + List of ontology subsets with relevant elements + """ + # Collect all relevant elements + relevant_elements = await self._find_relevant_elements(segments) + + # Group by ontology and build subsets + ontology_subsets = self._build_ontology_subsets(relevant_elements) + + # Resolve dependencies + for subset in ontology_subsets: + self._resolve_dependencies(subset) + + logger.info(f"Selected {len(ontology_subsets)} ontology subsets") + return ontology_subsets + + async def _find_relevant_elements(self, segments: List[TextSegment]) -> Set[Tuple[str, str, str, Dict]]: + """Find relevant ontology elements for text segments. 
+ + Args: + segments: Text segments to match + + Returns: + Set of (ontology_id, element_type, element_id, definition) tuples + """ + relevant_elements = set() + element_scores = defaultdict(float) + + # Process each segment + for segment in segments: + # Get embedding for segment + embedding = await self.embedder.embed_text(segment.text) + if embedding is None: + logger.warning(f"Failed to embed segment: {segment.text[:50]}...") + continue + + # Search vector store + results = self.embedder.get_vector_store().search( + embedding=embedding, + top_k=self.top_k, + threshold=self.similarity_threshold + ) + + # Process results + for result in results: + metadata = result.metadata + element_key = ( + metadata['ontology'], + metadata['type'], + metadata['element'], + str(metadata['definition']) # Convert dict to string for hashability + ) + relevant_elements.add(element_key) + # Track scores for ranking + element_scores[element_key] = max(element_scores[element_key], result.score) + + logger.debug(f"Found {len(relevant_elements)} relevant elements from {len(segments)} segments") + return relevant_elements + + def _build_ontology_subsets(self, relevant_elements: Set[Tuple[str, str, str, Dict]]) -> List[OntologySubset]: + """Build ontology subsets from relevant elements. + + Args: + relevant_elements: Set of relevant element tuples + + Returns: + List of ontology subsets + """ + # Group elements by ontology + ontology_groups = defaultdict(lambda: { + 'classes': {}, + 'object_properties': {}, + 'datatype_properties': {}, + 'scores': [] + }) + + for ont_id, elem_type, elem_id, definition in relevant_elements: + # Parse definition back from string if needed + if isinstance(definition, str): + import json + try: + definition = json.loads(definition.replace("'", '"')) + except: + definition = eval(definition) # Fallback for dict-like strings + + # Get the actual ontology and element + ontology = self.loader.get_ontology(ont_id) + if not ontology: + logger.warning(f"Ontology {ont_id} not found in loader") + continue + + # Add element to appropriate category + if elem_type == 'class': + cls = ontology.get_class(elem_id) + if cls: + ontology_groups[ont_id]['classes'][elem_id] = cls.__dict__ + elif elem_type == 'objectProperty': + prop = ontology.object_properties.get(elem_id) + if prop: + ontology_groups[ont_id]['object_properties'][elem_id] = prop.__dict__ + elif elem_type == 'datatypeProperty': + prop = ontology.datatype_properties.get(elem_id) + if prop: + ontology_groups[ont_id]['datatype_properties'][elem_id] = prop.__dict__ + + # Create OntologySubset objects + subsets = [] + for ont_id, elements in ontology_groups.items(): + ontology = self.loader.get_ontology(ont_id) + if ontology: + subset = OntologySubset( + ontology_id=ont_id, + classes=elements['classes'], + object_properties=elements['object_properties'], + datatype_properties=elements['datatype_properties'], + metadata=ontology.metadata, + relevance_score=sum(elements['scores']) / len(elements['scores']) if elements['scores'] else 0.0 + ) + subsets.append(subset) + + return subsets + + def _resolve_dependencies(self, subset: OntologySubset): + """Resolve dependencies for ontology subset elements. 
+ + Args: + subset: Ontology subset to resolve dependencies for + """ + ontology = self.loader.get_ontology(subset.ontology_id) + if not ontology: + return + + # Track classes to add + classes_to_add = set() + + # Resolve class hierarchies + for class_id in list(subset.classes.keys()): + # Add parent classes + parents = ontology.get_parent_classes(class_id) + for parent_id in parents: + parent_class = ontology.get_class(parent_id) + if parent_class and parent_id not in subset.classes: + classes_to_add.add(parent_id) + + # Resolve property domains and ranges + for prop_id, prop_def in subset.object_properties.items(): + # Add domain class + if 'domain' in prop_def and prop_def['domain']: + domain_id = prop_def['domain'] + if domain_id not in subset.classes: + domain_class = ontology.get_class(domain_id) + if domain_class: + classes_to_add.add(domain_id) + + # Add range class + if 'range' in prop_def and prop_def['range']: + range_id = prop_def['range'] + if range_id not in subset.classes: + range_class = ontology.get_class(range_id) + if range_class: + classes_to_add.add(range_id) + + # Resolve datatype property domains + for prop_id, prop_def in subset.datatype_properties.items(): + if 'domain' in prop_def and prop_def['domain']: + domain_id = prop_def['domain'] + if domain_id not in subset.classes: + domain_class = ontology.get_class(domain_id) + if domain_class: + classes_to_add.add(domain_id) + + # Add inverse properties + for prop_id, prop_def in list(subset.object_properties.items()): + if 'inverse_of' in prop_def and prop_def['inverse_of']: + inverse_id = prop_def['inverse_of'] + if inverse_id not in subset.object_properties: + inverse_prop = ontology.object_properties.get(inverse_id) + if inverse_prop: + subset.object_properties[inverse_id] = inverse_prop.__dict__ + + # Add collected classes + for class_id in classes_to_add: + cls = ontology.get_class(class_id) + if cls: + subset.classes[class_id] = cls.__dict__ + + logger.debug(f"Resolved dependencies for subset {subset.ontology_id}: " + f"added {len(classes_to_add)} classes") + + def merge_subsets(self, subsets: List[OntologySubset]) -> OntologySubset: + """Merge multiple ontology subsets into one. 
+ + Args: + subsets: List of subsets to merge + + Returns: + Merged ontology subset + """ + if not subsets: + return None + if len(subsets) == 1: + return subsets[0] + + # Use first subset as base + merged = OntologySubset( + ontology_id="merged", + classes={}, + object_properties={}, + datatype_properties={}, + metadata={}, + relevance_score=0.0 + ) + + # Merge all subsets + total_score = 0.0 + for subset in subsets: + # Merge classes + for class_id, class_def in subset.classes.items(): + key = f"{subset.ontology_id}:{class_id}" + merged.classes[key] = class_def + + # Merge object properties + for prop_id, prop_def in subset.object_properties.items(): + key = f"{subset.ontology_id}:{prop_id}" + merged.object_properties[key] = prop_def + + # Merge datatype properties + for prop_id, prop_def in subset.datatype_properties.items(): + key = f"{subset.ontology_id}:{prop_id}" + merged.datatype_properties[key] = prop_def + + total_score += subset.relevance_score + + # Average relevance score + merged.relevance_score = total_score / len(subsets) + + logger.info(f"Merged {len(subsets)} subsets into one with " + f"{len(merged.classes)} classes, " + f"{len(merged.object_properties)} object properties, " + f"{len(merged.datatype_properties)} datatype properties") + + return merged \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/run.py b/trustgraph-flow/trustgraph/extract/kg/ontology/run.py new file mode 100644 index 00000000..c0a6143b --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/run.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +""" +OntoRAG extraction service launcher. +""" + +from . extract import run + +if __name__ == "__main__": + run() \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py b/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py new file mode 100644 index 00000000..98563bba --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/text_processor.py @@ -0,0 +1,325 @@ +""" +Text processing components for OntoRAG system. +Splits text into sentences and extracts phrases for granular matching. 
+""" + +import logging +import re +from typing import List, Dict, Any, Optional +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + +# Try to import NLTK for advanced text processing +try: + import nltk + NLTK_AVAILABLE = True + # Try to ensure required NLTK data is downloaded + try: + nltk.data.find('tokenizers/punkt') + except LookupError: + try: + nltk.download('punkt', quiet=True) + except: + pass + try: + nltk.data.find('taggers/averaged_perceptron_tagger') + except LookupError: + try: + nltk.download('averaged_perceptron_tagger', quiet=True) + except: + pass + try: + nltk.data.find('corpora/stopwords') + except LookupError: + try: + nltk.download('stopwords', quiet=True) + except: + pass +except ImportError: + NLTK_AVAILABLE = False + logger.warning("NLTK not available, using basic text processing") + + +@dataclass +class TextSegment: + """Represents a segment of text (sentence or phrase).""" + text: str + type: str # 'sentence', 'phrase', 'noun_phrase', 'verb_phrase' + position: int + parent_sentence: Optional[str] = None + metadata: Dict[str, Any] = None + + +class SentenceSplitter: + """Splits text into sentences using available NLP tools.""" + + def __init__(self): + """Initialize sentence splitter.""" + self.use_nltk = NLTK_AVAILABLE + if self.use_nltk: + try: + self.sent_detector = nltk.data.load('tokenizers/punkt/english.pickle') + logger.info("Using NLTK sentence tokenizer") + except: + self.use_nltk = False + logger.warning("NLTK punkt tokenizer not available, using regex") + + def split(self, text: str) -> List[str]: + """Split text into sentences. + + Args: + text: Text to split + + Returns: + List of sentences + """ + if self.use_nltk: + try: + sentences = self.sent_detector.tokenize(text) + return sentences + except Exception as e: + logger.warning(f"NLTK sentence splitting failed: {e}, falling back to regex") + + # Fallback to regex-based splitting + # Simple sentence boundary detection + sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', text) + # Filter out empty sentences + sentences = [s.strip() for s in sentences if s.strip()] + return sentences + + +class PhraseExtractor: + """Extracts meaningful phrases from sentences.""" + + def __init__(self): + """Initialize phrase extractor.""" + self.use_nltk = NLTK_AVAILABLE + if self.use_nltk: + try: + # Test that POS tagger is available + nltk.pos_tag(['test']) + logger.info("Using NLTK phrase extraction") + except: + self.use_nltk = False + logger.warning("NLTK POS tagger not available, using basic extraction") + + def extract(self, sentence: str) -> List[Dict[str, str]]: + """Extract phrases from a sentence. + + Args: + sentence: Sentence to extract phrases from + + Returns: + List of phrases with their types + """ + phrases = [] + + if self.use_nltk: + try: + phrases.extend(self._extract_nltk_phrases(sentence)) + except Exception as e: + logger.warning(f"NLTK phrase extraction failed: {e}, using basic extraction") + phrases.extend(self._extract_basic_phrases(sentence)) + else: + phrases.extend(self._extract_basic_phrases(sentence)) + + return phrases + + def _extract_nltk_phrases(self, sentence: str) -> List[Dict[str, str]]: + """Extract phrases using NLTK. 
+ + Args: + sentence: Sentence to process + + Returns: + List of phrases with types + """ + phrases = [] + + try: + # Tokenize and POS tag + tokens = nltk.word_tokenize(sentence) + pos_tags = nltk.pos_tag(tokens) + + # Extract noun phrases (simple pattern) + noun_phrase = [] + for word, pos in pos_tags: + if pos.startswith('NN') or pos.startswith('JJ'): + noun_phrase.append(word) + elif noun_phrase: + if len(noun_phrase) > 1: + phrases.append({ + 'text': ' '.join(noun_phrase), + 'type': 'noun_phrase' + }) + noun_phrase = [] + + # Add last noun phrase if exists + if noun_phrase and len(noun_phrase) > 1: + phrases.append({ + 'text': ' '.join(noun_phrase), + 'type': 'noun_phrase' + }) + + # Extract verb phrases (simple pattern) + verb_phrase = [] + for word, pos in pos_tags: + if pos.startswith('VB') or pos.startswith('RB'): + verb_phrase.append(word) + elif verb_phrase: + if len(verb_phrase) > 1: + phrases.append({ + 'text': ' '.join(verb_phrase), + 'type': 'verb_phrase' + }) + verb_phrase = [] + + # Add last verb phrase if exists + if verb_phrase and len(verb_phrase) > 1: + phrases.append({ + 'text': ' '.join(verb_phrase), + 'type': 'verb_phrase' + }) + + except Exception as e: + logger.error(f"Error in NLTK phrase extraction: {e}") + + return phrases + + def _extract_basic_phrases(self, sentence: str) -> List[Dict[str, str]]: + """Extract phrases using basic regex patterns. + + Args: + sentence: Sentence to process + + Returns: + List of phrases with types + """ + phrases = [] + + # Extract quoted phrases + quoted = re.findall(r'"([^"]+)"', sentence) + for q in quoted: + phrases.append({'text': q, 'type': 'phrase'}) + + # Extract parenthetical phrases + parens = re.findall(r'\(([^)]+)\)', sentence) + for p in parens: + phrases.append({'text': p, 'type': 'phrase'}) + + # Extract capitalized sequences (potential entities) + caps = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', sentence) + for c in caps: + if len(c.split()) > 1: # Multi-word entities + phrases.append({'text': c, 'type': 'noun_phrase'}) + + return phrases + + +class TextProcessor: + """Main text processing class that coordinates sentence splitting and phrase extraction.""" + + def __init__(self): + """Initialize text processor.""" + self.sentence_splitter = SentenceSplitter() + self.phrase_extractor = PhraseExtractor() + + def process_chunk(self, chunk_text: str, extract_phrases: bool = True) -> List[TextSegment]: + """Process a text chunk into segments. + + Args: + chunk_text: Text chunk to process + extract_phrases: Whether to extract phrases from sentences + + Returns: + List of TextSegment objects + """ + segments = [] + position = 0 + + # Split into sentences + sentences = self.sentence_splitter.split(chunk_text) + + for sentence in sentences: + # Add sentence segment + segments.append(TextSegment( + text=sentence, + type='sentence', + position=position + )) + position += 1 + + # Extract phrases if requested + if extract_phrases: + phrases = self.phrase_extractor.extract(sentence) + for phrase_data in phrases: + segments.append(TextSegment( + text=phrase_data['text'], + type=phrase_data['type'], + position=position, + parent_sentence=sentence + )) + position += 1 + + logger.debug(f"Processed chunk into {len(segments)} segments") + return segments + + def extract_key_terms(self, text: str) -> List[str]: + """Extract key terms from text for matching. 
+ + Args: + text: Text to extract terms from + + Returns: + List of key terms + """ + terms = [] + + # Split on word boundaries + words = re.findall(r'\b\w+\b', text.lower()) + + # Filter common stop words (basic list) + stop_words = { + 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', + 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be', + 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', + 'would', 'could', 'should', 'may', 'might', 'must', 'can', 'shall' + } + + # Use NLTK stopwords if available + if NLTK_AVAILABLE: + try: + from nltk.corpus import stopwords + stop_words = set(stopwords.words('english')) + except: + pass + + # Filter stopwords and short words + terms = [w for w in words if w not in stop_words and len(w) > 2] + + # Also extract multi-word terms (bigrams) + for i in range(len(words) - 1): + if words[i] not in stop_words and words[i+1] not in stop_words: + bigram = f"{words[i]} {words[i+1]}" + terms.append(bigram) + + return terms + + def normalize_text(self, text: str) -> str: + """Normalize text for consistent processing. + + Args: + text: Text to normalize + + Returns: + Normalized text + """ + # Remove extra whitespace + text = re.sub(r'\s+', ' ', text) + # Remove leading/trailing whitespace + text = text.strip() + # Normalize quotes + text = text.replace('"', '"').replace('"', '"') + text = text.replace(''', "'").replace(''', "'") + return text \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py b/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py new file mode 100644 index 00000000..42c3dc7f --- /dev/null +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/vector_store.py @@ -0,0 +1,267 @@ +""" +Vector store implementations for OntoRAG system. +Provides both FAISS and NumPy-based vector storage for ontology embeddings. +""" + +import logging +import numpy as np +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + +# Try to import FAISS, fall back to NumPy implementation if not available +try: + import faiss + FAISS_AVAILABLE = True +except ImportError: + FAISS_AVAILABLE = False + logger.warning("FAISS not available, using NumPy implementation") + + +@dataclass +class SearchResult: + """Result from vector similarity search.""" + id: str + score: float + metadata: Dict[str, Any] + + +class VectorStore: + """Abstract base class for vector stores.""" + + def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]): + """Add single embedding with metadata.""" + raise NotImplementedError + + def add_batch(self, ids: List[str], embeddings: np.ndarray, + metadata_list: List[Dict[str, Any]]): + """Batch add for initial ontology loading.""" + raise NotImplementedError + + def search(self, embedding: np.ndarray, top_k: int = 10, + threshold: float = 0.0) -> List[SearchResult]: + """Search for similar vectors.""" + raise NotImplementedError + + def clear(self): + """Reset the store.""" + raise NotImplementedError + + def size(self) -> int: + """Return number of stored vectors.""" + raise NotImplementedError + + +class FAISSVectorStore(VectorStore): + """FAISS-based vector store implementation.""" + + def __init__(self, dimension: int = 1536, index_type: str = 'flat'): + """Initialize FAISS vector store. 
+ + Args: + dimension: Embedding dimension (1536 for text-embedding-3-small) + index_type: 'flat' for exact search, 'ivf' for larger datasets + """ + if not FAISS_AVAILABLE: + raise RuntimeError("FAISS is not installed") + + self.dimension = dimension + self.metadata = [] + self.ids = [] + + if index_type == 'flat': + # Exact search - best for ontologies with <10k elements + self.index = faiss.IndexFlatIP(dimension) + logger.info(f"Created FAISS flat index with dimension {dimension}") + else: + # Approximate search - for larger ontologies + quantizer = faiss.IndexFlatIP(dimension) + self.index = faiss.IndexIVFFlat(quantizer, dimension, 100) + # Train with random vectors for initialization + training_data = np.random.randn(1000, dimension).astype('float32') + training_data = training_data / np.linalg.norm( + training_data, axis=1, keepdims=True + ) + self.index.train(training_data) + logger.info(f"Created FAISS IVF index with dimension {dimension}") + + def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]): + """Add single embedding with metadata.""" + # Normalize for cosine similarity + embedding = embedding / np.linalg.norm(embedding) + self.index.add(np.array([embedding], dtype=np.float32)) + self.metadata.append(metadata) + self.ids.append(id) + + def add_batch(self, ids: List[str], embeddings: np.ndarray, + metadata_list: List[Dict[str, Any]]): + """Batch add for initial ontology loading.""" + # Normalize all embeddings + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + normalized = embeddings / norms + self.index.add(normalized.astype(np.float32)) + self.metadata.extend(metadata_list) + self.ids.extend(ids) + logger.debug(f"Added batch of {len(ids)} embeddings to FAISS index") + + def search(self, embedding: np.ndarray, top_k: int = 10, + threshold: float = 0.0) -> List[SearchResult]: + """Search for similar vectors.""" + # Normalize query + embedding = embedding / np.linalg.norm(embedding) + + # Search + scores, indices = self.index.search( + np.array([embedding], dtype=np.float32), + min(top_k, self.index.ntotal) + ) + + # Filter by threshold and format results + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx >= 0 and score >= threshold: # FAISS returns -1 for empty slots + results.append(SearchResult( + id=self.ids[idx], + score=float(score), + metadata=self.metadata[idx] + )) + + return results + + def clear(self): + """Reset the store.""" + self.index.reset() + self.metadata = [] + self.ids = [] + logger.info("Cleared FAISS vector store") + + def size(self) -> int: + """Return number of stored vectors.""" + return self.index.ntotal + + +class SimpleVectorStore(VectorStore): + """NumPy-based vector store implementation for development/small deployments.""" + + def __init__(self): + """Initialize simple NumPy-based vector store.""" + self.embeddings = [] + self.metadata = [] + self.ids = [] + logger.info("Created SimpleVectorStore (NumPy implementation)") + + def add(self, id: str, embedding: np.ndarray, metadata: Dict[str, Any]): + """Add single embedding with metadata.""" + # Normalize for cosine similarity + normalized = embedding / np.linalg.norm(embedding) + self.embeddings.append(normalized) + self.metadata.append(metadata) + self.ids.append(id) + + def add_batch(self, ids: List[str], embeddings: np.ndarray, + metadata_list: List[Dict[str, Any]]): + """Batch add for initial ontology loading.""" + # Normalize all embeddings + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + # Avoid division by zero + norms = 
np.where(norms == 0, 1, norms) + normalized = embeddings / norms + + for i in range(len(ids)): + self.embeddings.append(normalized[i]) + self.metadata.append(metadata_list[i]) + self.ids.append(ids[i]) + + logger.debug(f"Added batch of {len(ids)} embeddings to simple store") + + def search(self, embedding: np.ndarray, top_k: int = 10, + threshold: float = 0.0) -> List[SearchResult]: + """Search for similar vectors using cosine similarity.""" + if not self.embeddings: + return [] + + # Normalize query embedding + embedding = embedding / np.linalg.norm(embedding) + + # Compute cosine similarities + embeddings_array = np.array(self.embeddings) + similarities = np.dot(embeddings_array, embedding) + + # Get top-k indices + top_k = min(top_k, len(self.embeddings)) + top_indices = np.argsort(similarities)[::-1][:top_k] + + # Build results + results = [] + for idx in top_indices: + if similarities[idx] >= threshold: + results.append(SearchResult( + id=self.ids[idx], + score=float(similarities[idx]), + metadata=self.metadata[idx] + )) + + return results + + def clear(self): + """Reset the store.""" + self.embeddings = [] + self.metadata = [] + self.ids = [] + logger.info("Cleared simple vector store") + + def size(self) -> int: + """Return number of stored vectors.""" + return len(self.embeddings) + + +class InMemoryVectorStore: + """Factory class to create appropriate vector store based on availability.""" + + @staticmethod + def create(dimension: int = 1536, prefer_faiss: bool = True, + index_type: str = 'flat') -> VectorStore: + """Create a vector store instance. + + Args: + dimension: Embedding dimension + prefer_faiss: Whether to prefer FAISS if available + index_type: Type of FAISS index ('flat' or 'ivf') + + Returns: + VectorStore instance (FAISS or Simple) + """ + if prefer_faiss and FAISS_AVAILABLE: + try: + return FAISSVectorStore(dimension, index_type) + except Exception as e: + logger.warning(f"Failed to create FAISS store: {e}, falling back to NumPy") + return SimpleVectorStore() + else: + return SimpleVectorStore() + + +# Utility functions for vector operations +def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + """Compute cosine similarity between two vectors.""" + return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + + +def batch_cosine_similarity(queries: np.ndarray, targets: np.ndarray) -> np.ndarray: + """Compute cosine similarity between query vectors and target vectors. + + Args: + queries: Array of shape (n_queries, dimension) + targets: Array of shape (n_targets, dimension) + + Returns: + Array of shape (n_queries, n_targets) with similarity scores + """ + # Normalize queries and targets + queries_norm = queries / np.linalg.norm(queries, axis=1, keepdims=True) + targets_norm = targets / np.linalg.norm(targets, axis=1, keepdims=True) + + # Compute dot product + similarities = np.dot(queries_norm, targets_norm.T) + return similarities \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/__init__.py b/trustgraph-flow/trustgraph/query/ontology/__init__.py new file mode 100644 index 00000000..60557ea9 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/__init__.py @@ -0,0 +1,54 @@ +""" +OntoRAG Query System. + +Ontology-driven natural language query processing with multi-backend support. +Provides semantic query understanding, ontology matching, and answer generation. 
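+
+Typical imports (names match the re-exports below; wiring up the service
+itself is configuration-dependent and not shown here):
+
+ from trustgraph.query.ontology import OntoRAGQueryService, QueryRequest
+ from trustgraph.query.ontology import BackendRouter, AnswerGenerator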
+""" + +from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse +from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType +from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .backend_router import BackendRouter, BackendType, QueryRoute +from .sparql_generator import SPARQLGenerator, SPARQLQuery +from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult +from .cypher_generator import CypherGenerator, CypherQuery +from .cypher_executor import CypherExecutor, CypherResult +from .answer_generator import AnswerGenerator, GeneratedAnswer, AnswerMetadata + +__all__ = [ + # Main service + 'OntoRAGQueryService', + 'QueryRequest', + 'QueryResponse', + + # Question analysis + 'QuestionAnalyzer', + 'QuestionComponents', + 'QuestionType', + + # Ontology matching + 'OntologyMatcher', + 'QueryOntologySubset', + + # Backend routing + 'BackendRouter', + 'BackendType', + 'QueryRoute', + + # SPARQL components + 'SPARQLGenerator', + 'SPARQLQuery', + 'SPARQLCassandraEngine', + 'SPARQLResult', + + # Cypher components + 'CypherGenerator', + 'CypherQuery', + 'CypherExecutor', + 'CypherResult', + + # Answer generation + 'AnswerGenerator', + 'GeneratedAnswer', + 'AnswerMetadata', +] \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/answer_generator.py b/trustgraph-flow/trustgraph/query/ontology/answer_generator.py new file mode 100644 index 00000000..9b4b6ba7 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/answer_generator.py @@ -0,0 +1,521 @@ +""" +Answer generator for natural language responses. +Converts query results into natural language answers using LLM assistance. +""" + +import logging +from typing import Dict, Any, List, Optional, Union +from dataclasses import dataclass +from datetime import datetime + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset +from .sparql_cassandra import SPARQLResult +from .cypher_executor import CypherResult + +logger = logging.getLogger(__name__) + + +@dataclass +class AnswerMetadata: + """Metadata about answer generation.""" + query_type: str + backend_used: str + execution_time: float + result_count: int + confidence: float + explanation: str + sources: List[str] + + +@dataclass +class GeneratedAnswer: + """Generated natural language answer.""" + answer: str + metadata: AnswerMetadata + supporting_facts: List[str] + raw_results: Union[SPARQLResult, CypherResult] + generation_time: float + + +class AnswerGenerator: + """Generates natural language answers from query results.""" + + def __init__(self, prompt_service=None): + """Initialize answer generator. + + Args: + prompt_service: Service for LLM-based answer generation + """ + self.prompt_service = prompt_service + + # Answer templates for different question types + self.templates = { + 'count': "There are {count} {entity_type}.", + 'boolean_true': "Yes, {statement} is true.", + 'boolean_false': "No, {statement} is not true.", + 'list': "The {entity_type} are: {items}.", + 'single': "The {property} of {entity} is {value}.", + 'none': "No results were found for your query.", + 'error': "I encountered an error processing your query: {error}" + } + + async def generate_answer(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + ontology_subset: QueryOntologySubset, + backend_used: str) -> GeneratedAnswer: + """Generate natural language answer from query results. 
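+
+ Example (illustrative; assumes the caller already holds the question
+ analysis, query results, and ontology subset from the surrounding pipeline):
+ generator = AnswerGenerator(prompt_service=None)  # template-only mode
+ answer = await generator.generate_answer(
+ components, sparql_result, subset, backend_used="cassandra")
+ print(answer.answer, answer.metadata.confidence)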
+ + Args: + question_components: Original question analysis + query_results: Results from query execution + ontology_subset: Ontology subset used + backend_used: Backend that executed the query + + Returns: + Generated answer with metadata + """ + start_time = datetime.now() + + try: + # Try LLM-based generation first + if self.prompt_service: + llm_answer = await self._generate_with_llm( + question_components, query_results, ontology_subset + ) + if llm_answer: + execution_time = (datetime.now() - start_time).total_seconds() + return self._build_answer_response( + llm_answer, question_components, query_results, + backend_used, execution_time + ) + + # Fall back to template-based generation + template_answer = self._generate_with_template( + question_components, query_results, ontology_subset + ) + + execution_time = (datetime.now() - start_time).total_seconds() + return self._build_answer_response( + template_answer, question_components, query_results, + backend_used, execution_time + ) + + except Exception as e: + logger.error(f"Answer generation failed: {e}") + execution_time = (datetime.now() - start_time).total_seconds() + error_answer = self.templates['error'].format(error=str(e)) + return self._build_answer_response( + error_answer, question_components, query_results, + backend_used, execution_time, confidence=0.0 + ) + + async def _generate_with_llm(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + ontology_subset: QueryOntologySubset) -> Optional[str]: + """Generate answer using LLM. + + Args: + question_components: Question analysis + query_results: Query results + ontology_subset: Ontology subset + + Returns: + Generated answer or None if failed + """ + try: + prompt = self._build_answer_prompt( + question_components, query_results, ontology_subset + ) + response = await self.prompt_service.generate_answer(prompt=prompt) + + if response and isinstance(response, dict): + return response.get('answer', '').strip() + elif isinstance(response, str): + return response.strip() + + except Exception as e: + logger.error(f"LLM answer generation failed: {e}") + + return None + + def _generate_with_template(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + ontology_subset: QueryOntologySubset) -> str: + """Generate answer using templates. 
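+
+ Example (illustrative values): an aggregation question such as "How many
+ animals are there?" with 12 matches fills the 'count' template and yields
+ "There are 12 animal." (the inferred entity type is interpolated verbatim,
+ with no pluralisation); an empty result set yields the 'none' template.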
+ + Args: + question_components: Question analysis + query_results: Query results + ontology_subset: Ontology subset + + Returns: + Template-based answer + """ + # Handle empty results + if not self._has_results(query_results): + return self.templates['none'] + + # Handle boolean queries + if question_components.question_type == QuestionType.BOOLEAN: + if hasattr(query_results, 'ask_result'): + # SPARQL ASK result + statement = self._extract_boolean_statement(question_components) + if query_results.ask_result: + return self.templates['boolean_true'].format(statement=statement) + else: + return self.templates['boolean_false'].format(statement=statement) + else: + # Cypher boolean (check if any results) + has_results = len(query_results.records) > 0 + statement = self._extract_boolean_statement(question_components) + if has_results: + return self.templates['boolean_true'].format(statement=statement) + else: + return self.templates['boolean_false'].format(statement=statement) + + # Handle count queries + if question_components.question_type == QuestionType.AGGREGATION: + count = self._extract_count(query_results) + entity_type = self._infer_entity_type(question_components, ontology_subset) + return self.templates['count'].format(count=count, entity_type=entity_type) + + # Handle retrieval queries + if question_components.question_type == QuestionType.RETRIEVAL: + items = self._extract_items(query_results) + if len(items) == 1: + # Single result + entity = question_components.entities[0] if question_components.entities else "entity" + property_name = "value" + return self.templates['single'].format( + property=property_name, entity=entity, value=items[0] + ) + else: + # Multiple results + entity_type = self._infer_entity_type(question_components, ontology_subset) + items_str = ", ".join(items) + return self.templates['list'].format(entity_type=entity_type, items=items_str) + + # Handle factual queries + if question_components.question_type == QuestionType.FACTUAL: + facts = self._extract_facts(query_results) + return ". ".join(facts) if facts else self.templates['none'] + + # Default fallback + items = self._extract_items(query_results) + if items: + return f"Found: {', '.join(items[:5])}" + ("..." if len(items) > 5 else "") + else: + return self.templates['none'] + + def _build_answer_prompt(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + ontology_subset: QueryOntologySubset) -> str: + """Build prompt for LLM answer generation. + + Args: + question_components: Question analysis + query_results: Query results + ontology_subset: Ontology subset + + Returns: + Formatted prompt string + """ + # Format results for prompt + results_str = self._format_results_for_prompt(query_results) + + # Extract ontology context + context_classes = list(ontology_subset.classes.keys())[:5] + context_properties = list(ontology_subset.object_properties.keys())[:5] + + prompt = f"""Generate a natural language answer for the following question based on the query results. 
+ +ORIGINAL QUESTION: {question_components.original_question} + +QUESTION TYPE: {question_components.question_type.value} +EXPECTED ANSWER: {question_components.expected_answer_type} + +ONTOLOGY CONTEXT: +- Classes: {', '.join(context_classes)} +- Properties: {', '.join(context_properties)} + +QUERY RESULTS: +{results_str} + +INSTRUCTIONS: +- Provide a clear, concise answer in natural language +- Use the original question's tone and style +- Include specific facts from the results +- If no results, explain that no information was found +- Be accurate and don't make assumptions beyond the data +- Limit response to 2-3 sentences unless the question requires more detail + +ANSWER:""" + + return prompt + + def _format_results_for_prompt(self, query_results: Union[SPARQLResult, CypherResult]) -> str: + """Format query results for prompt inclusion. + + Args: + query_results: Query results to format + + Returns: + Formatted results string + """ + if isinstance(query_results, SPARQLResult): + if hasattr(query_results, 'ask_result') and query_results.ask_result is not None: + return f"Boolean result: {query_results.ask_result}" + + if not query_results.bindings: + return "No results found" + + # Format SPARQL bindings + lines = [] + for binding in query_results.bindings[:10]: # Limit to first 10 + formatted = [] + for var, value in binding.items(): + if isinstance(value, dict): + formatted.append(f"{var}: {value.get('value', value)}") + else: + formatted.append(f"{var}: {value}") + lines.append("- " + ", ".join(formatted)) + + if len(query_results.bindings) > 10: + lines.append(f"... and {len(query_results.bindings) - 10} more results") + + return "\n".join(lines) + + else: # CypherResult + if not query_results.records: + return "No results found" + + # Format Cypher records + lines = [] + for record in query_results.records[:10]: # Limit to first 10 + if isinstance(record, dict): + formatted = [f"{k}: {v}" for k, v in record.items()] + lines.append("- " + ", ".join(formatted)) + else: + lines.append(f"- {record}") + + if len(query_results.records) > 10: + lines.append(f"... and {len(query_results.records) - 10} more results") + + return "\n".join(lines) + + def _has_results(self, query_results: Union[SPARQLResult, CypherResult]) -> bool: + """Check if query results contain data. + + Args: + query_results: Query results to check + + Returns: + True if results contain data + """ + if isinstance(query_results, SPARQLResult): + return bool(query_results.bindings) or query_results.ask_result is not None + else: # CypherResult + return bool(query_results.records) + + def _extract_count(self, query_results: Union[SPARQLResult, CypherResult]) -> int: + """Extract count from aggregation query results. 
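+
+ Example (illustrative; the binding shown is hypothetical):
+ # bindings == [{"count": {"value": "42"}}]  ->  returns 42
+ # bindings with no count-like variable      ->  falls back to len(bindings)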
+ + Args: + query_results: Query results + + Returns: + Count value + """ + if isinstance(query_results, SPARQLResult): + if query_results.bindings: + binding = query_results.bindings[0] + # Look for count variable + for var, value in binding.items(): + if 'count' in var.lower(): + if isinstance(value, dict): + return int(value.get('value', 0)) + return int(value) + return len(query_results.bindings) + else: # CypherResult + if query_results.records: + record = query_results.records[0] + if isinstance(record, dict): + # Look for count key + for key, value in record.items(): + if 'count' in key.lower(): + return int(value) + elif isinstance(record, (int, float)): + return int(record) + return len(query_results.records) + + def _extract_items(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]: + """Extract items from query results. + + Args: + query_results: Query results + + Returns: + List of extracted items + """ + items = [] + + if isinstance(query_results, SPARQLResult): + for binding in query_results.bindings: + for var, value in binding.items(): + if isinstance(value, dict): + item_value = value.get('value', str(value)) + else: + item_value = str(value) + + # Clean up URIs + if item_value.startswith('http'): + item_value = item_value.split('/')[-1].split('#')[-1] + + items.append(item_value) + break # Take first value per binding + + else: # CypherResult + for record in query_results.records: + if isinstance(record, dict): + # Take first value from record + for key, value in record.items(): + items.append(str(value)) + break + else: + items.append(str(record)) + + return items + + def _extract_facts(self, query_results: Union[SPARQLResult, CypherResult]) -> List[str]: + """Extract facts from query results. + + Args: + query_results: Query results + + Returns: + List of facts + """ + facts = [] + + if isinstance(query_results, SPARQLResult): + for binding in query_results.bindings: + fact_parts = [] + for var, value in binding.items(): + if isinstance(value, dict): + val_str = value.get('value', str(value)) + else: + val_str = str(value) + + # Clean up URIs + if val_str.startswith('http'): + val_str = val_str.split('/')[-1].split('#')[-1] + + fact_parts.append(f"{var}: {val_str}") + + facts.append(", ".join(fact_parts)) + + else: # CypherResult + for record in query_results.records: + if isinstance(record, dict): + fact_parts = [f"{k}: {v}" for k, v in record.items()] + facts.append(", ".join(fact_parts)) + else: + facts.append(str(record)) + + return facts + + def _extract_boolean_statement(self, question_components: QuestionComponents) -> str: + """Extract statement for boolean answer. + + Args: + question_components: Question analysis + + Returns: + Statement string + """ + # Extract the key assertion from the question + question = question_components.original_question.lower() + + # Remove question words + statement = question.replace('is ', '').replace('are ', '').replace('does ', '') + statement = statement.replace('?', '').strip() + + return statement + + def _infer_entity_type(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> str: + """Infer entity type from question and ontology. 
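+
+ Example (illustrative values):
+ # entities == ["Animal"], classes contain "animal"  ->  "animal" (case-insensitive match)
+ # entities == []                                    ->  "entities" (generic fallback)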
+ + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Entity type string + """ + # Try to match entities to ontology classes + for entity in question_components.entities: + entity_lower = entity.lower() + for class_id in ontology_subset.classes: + if class_id.lower() == entity_lower or entity_lower in class_id.lower(): + return class_id + + # Fallback to first entity or generic term + if question_components.entities: + return question_components.entities[0] + else: + return "entities" + + def _build_answer_response(self, + answer: str, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + backend_used: str, + execution_time: float, + confidence: float = 0.8) -> GeneratedAnswer: + """Build final answer response. + + Args: + answer: Generated answer text + question_components: Question analysis + query_results: Query results + backend_used: Backend used for query + execution_time: Answer generation time + confidence: Confidence score + + Returns: + Complete answer response + """ + # Extract supporting facts + supporting_facts = self._extract_facts(query_results) + + # Build metadata + result_count = 0 + if isinstance(query_results, SPARQLResult): + result_count = len(query_results.bindings) + else: # CypherResult + result_count = len(query_results.records) + + metadata = AnswerMetadata( + query_type=question_components.question_type.value, + backend_used=backend_used, + execution_time=execution_time, + result_count=result_count, + confidence=confidence, + explanation=f"Generated answer using {backend_used} backend", + sources=[] # Could be populated with data source information + ) + + return GeneratedAnswer( + answer=answer, + metadata=metadata, + supporting_facts=supporting_facts[:5], # Limit to top 5 + raw_results=query_results, + generation_time=execution_time + ) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/backend_router.py b/trustgraph-flow/trustgraph/query/ontology/backend_router.py new file mode 100644 index 00000000..cbd23530 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/backend_router.py @@ -0,0 +1,350 @@ +""" +Backend router for ontology query system. +Routes queries to appropriate backend based on configuration. +""" + +import logging +from typing import Dict, Any, Optional, List +from dataclasses import dataclass +from enum import Enum + +from .question_analyzer import QuestionComponents +from .ontology_matcher import QueryOntologySubset + +logger = logging.getLogger(__name__) + + +class BackendType(Enum): + """Supported backend types.""" + CASSANDRA = "cassandra" + NEO4J = "neo4j" + MEMGRAPH = "memgraph" + FALKORDB = "falkordb" + + +@dataclass +class BackendConfig: + """Configuration for a backend.""" + type: BackendType + priority: int = 0 + enabled: bool = True + config: Dict[str, Any] = None + + +@dataclass +class QueryRoute: + """Routing decision for a query.""" + backend_type: BackendType + query_language: str # 'sparql' or 'cypher' + confidence: float + reasoning: str + + +class BackendRouter: + """Routes queries to appropriate backends based on configuration and heuristics.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize backend router. 
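+
+ Example (an illustrative configuration; the key names follow
+ _parse_backend_config below):
+ router = BackendRouter({
+ "primary": "cassandra",
+ "fallback": ["neo4j", "falkordb"],
+ "routing_strategy": "adaptive",
+ "enable_fallback": True,
+ })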
+ + Args: + config: Router configuration + """ + self.config = config + self.backends = self._parse_backend_config(config) + self.routing_strategy = config.get('routing_strategy', 'priority') + self.enable_fallback = config.get('enable_fallback', True) + + def _parse_backend_config(self, config: Dict[str, Any]) -> Dict[BackendType, BackendConfig]: + """Parse backend configuration. + + Args: + config: Configuration dictionary + + Returns: + Dictionary of backend type to configuration + """ + backends = {} + + # Parse primary backend + primary = config.get('primary', 'cassandra') + if primary: + try: + backend_type = BackendType(primary) + backends[backend_type] = BackendConfig( + type=backend_type, + priority=100, + enabled=True, + config=config.get(primary, {}) + ) + except ValueError: + logger.warning(f"Unknown primary backend type: {primary}") + + # Parse fallback backends + fallbacks = config.get('fallback', []) + for i, fallback in enumerate(fallbacks): + try: + backend_type = BackendType(fallback) + backends[backend_type] = BackendConfig( + type=backend_type, + priority=50 - i * 10, # Decreasing priority + enabled=True, + config=config.get(fallback, {}) + ) + except ValueError: + logger.warning(f"Unknown fallback backend type: {fallback}") + + return backends + + def route_query(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset]) -> QueryRoute: + """Route a query to the best backend. + + Args: + question_components: Analyzed question + ontology_subsets: Relevant ontology subsets + + Returns: + QueryRoute with routing decision + """ + if self.routing_strategy == 'priority': + return self._route_by_priority() + elif self.routing_strategy == 'adaptive': + return self._route_adaptive(question_components, ontology_subsets) + elif self.routing_strategy == 'round_robin': + return self._route_round_robin() + else: + return self._route_by_priority() + + def _route_by_priority(self) -> QueryRoute: + """Route based on backend priority. + + Returns: + QueryRoute to highest priority backend + """ + # Find highest priority enabled backend + best_backend = None + best_priority = -1 + + for backend_type, backend_config in self.backends.items(): + if backend_config.enabled and backend_config.priority > best_priority: + best_backend = backend_type + best_priority = backend_config.priority + + if best_backend is None: + raise RuntimeError("No enabled backends available") + + # Determine query language + query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher' + + return QueryRoute( + backend_type=best_backend, + query_language=query_language, + confidence=1.0, + reasoning=f"Priority routing to {best_backend.value}" + ) + + def _route_adaptive(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset]) -> QueryRoute: + """Route based on question characteristics and ontology complexity. 
+ + Args: + question_components: Analyzed question + ontology_subsets: Relevant ontology subsets + + Returns: + QueryRoute with adaptive decision + """ + scores = {} + + for backend_type, backend_config in self.backends.items(): + if not backend_config.enabled: + continue + + score = self._calculate_backend_score( + backend_type, question_components, ontology_subsets + ) + scores[backend_type] = score + + if not scores: + raise RuntimeError("No enabled backends available") + + # Select backend with highest score + best_backend = max(scores.keys(), key=lambda k: scores[k]) + best_score = scores[best_backend] + + # Determine query language + query_language = 'sparql' if best_backend == BackendType.CASSANDRA else 'cypher' + + return QueryRoute( + backend_type=best_backend, + query_language=query_language, + confidence=best_score, + reasoning=f"Adaptive routing: {best_backend.value} scored {best_score:.2f}" + ) + + def _calculate_backend_score(self, + backend_type: BackendType, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset]) -> float: + """Calculate score for a backend based on query characteristics. + + Args: + backend_type: Backend to score + question_components: Question analysis + ontology_subsets: Ontology subsets + + Returns: + Score (0.0 to 1.0) + """ + score = 0.0 + + # Base priority score + backend_config = self.backends[backend_type] + score += backend_config.priority / 100.0 + + # Question type preferences + if backend_type == BackendType.CASSANDRA: + # SPARQL is good for hierarchical and complex reasoning + if question_components.question_type.value in ['factual', 'aggregation']: + score += 0.3 + # Good for ontology-heavy queries + if len(ontology_subsets) > 1: + score += 0.2 + else: + # Cypher is good for graph traversal and relationships + if question_components.question_type.value in ['relationship', 'retrieval']: + score += 0.3 + # Good for simple graph patterns + if len(question_components.relationships) > 0: + score += 0.2 + + # Complexity considerations + total_elements = sum( + len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties) + for subset in ontology_subsets + ) + + if backend_type == BackendType.CASSANDRA: + # SPARQL handles complex ontologies well + if total_elements > 20: + score += 0.2 + else: + # Cypher is efficient for simpler queries + if total_elements <= 10: + score += 0.2 + + # Aggregation considerations + if question_components.aggregations: + if backend_type == BackendType.CASSANDRA: + score += 0.1 # SPARQL has built-in aggregation + else: + score += 0.2 # Cypher has excellent aggregation + + return min(score, 1.0) + + def _route_round_robin(self) -> QueryRoute: + """Route using round-robin strategy. + + Returns: + QueryRoute using round-robin selection + """ + # Simple round-robin implementation + enabled_backends = [ + bt for bt, bc in self.backends.items() if bc.enabled + ] + + if not enabled_backends: + raise RuntimeError("No enabled backends available") + + # For simplicity, just return the first enabled backend + # In a real implementation, you'd track state + backend_type = enabled_backends[0] + query_language = 'sparql' if backend_type == BackendType.CASSANDRA else 'cypher' + + return QueryRoute( + backend_type=backend_type, + query_language=query_language, + confidence=0.8, + reasoning=f"Round-robin routing to {backend_type.value}" + ) + + def get_fallback_route(self, failed_backend: BackendType) -> Optional[QueryRoute]: + """Get fallback route when a backend fails. 
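+
+ Example (illustrative; assumes cassandra is the configured primary and
+ neo4j an enabled fallback):
+ route = router.get_fallback_route(BackendType.CASSANDRA)
+ # -> QueryRoute(backend_type=BackendType.NEO4J, query_language='cypher', ...)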
+ + Args: + failed_backend: Backend that failed + + Returns: + Fallback route or None if no fallback available + """ + if not self.enable_fallback: + return None + + # Find next best backend + fallback_backends = [ + (bt, bc) for bt, bc in self.backends.items() + if bc.enabled and bt != failed_backend + ] + + if not fallback_backends: + return None + + # Sort by priority + fallback_backends.sort(key=lambda x: x[1].priority, reverse=True) + fallback_type = fallback_backends[0][0] + + query_language = 'sparql' if fallback_type == BackendType.CASSANDRA else 'cypher' + + return QueryRoute( + backend_type=fallback_type, + query_language=query_language, + confidence=0.7, + reasoning=f"Fallback from {failed_backend.value} to {fallback_type.value}" + ) + + def get_available_backends(self) -> List[BackendType]: + """Get list of available backends. + + Returns: + List of enabled backend types + """ + return [bt for bt, bc in self.backends.items() if bc.enabled] + + def is_backend_enabled(self, backend_type: BackendType) -> bool: + """Check if a backend is enabled. + + Args: + backend_type: Backend to check + + Returns: + True if backend is enabled + """ + backend_config = self.backends.get(backend_type) + return backend_config is not None and backend_config.enabled + + def update_backend_status(self, backend_type: BackendType, enabled: bool): + """Update backend enabled status. + + Args: + backend_type: Backend to update + enabled: New enabled status + """ + if backend_type in self.backends: + self.backends[backend_type].enabled = enabled + logger.info(f"Backend {backend_type.value} {'enabled' if enabled else 'disabled'}") + else: + logger.warning(f"Unknown backend type: {backend_type}") + + def get_backend_config(self, backend_type: BackendType) -> Optional[Dict[str, Any]]: + """Get configuration for a backend. + + Args: + backend_type: Backend type + + Returns: + Configuration dictionary or None + """ + backend_config = self.backends.get(backend_type) + return backend_config.config if backend_config else None \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/cache.py b/trustgraph-flow/trustgraph/query/ontology/cache.py new file mode 100644 index 00000000..266bd805 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/cache.py @@ -0,0 +1,651 @@ +""" +Caching system for OntoRAG query results and computations. +Provides multiple cache backends and intelligent cache management. 
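+
+Typical usage (a sketch; the configuration keys mirror CacheManager below,
+and the cached value is a placeholder):
+
+ cache = CacheManager({"default_backend": "memory", "default_ttl_seconds": 600})
+ cache.set("query:animals", results, tags=["sparql"])
+ hit = cache.get("query:animals")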
+""" + +import logging +import time +import json +import pickle +from typing import Dict, Any, List, Optional, Union, Tuple +from dataclasses import dataclass, asdict +from datetime import datetime, timedelta +from abc import ABC, abstractmethod +from pathlib import Path +import threading + +logger = logging.getLogger(__name__) + + +@dataclass +class CacheEntry: + """Cache entry with metadata.""" + key: str + value: Any + created_at: datetime + accessed_at: datetime + access_count: int + ttl_seconds: Optional[int] = None + tags: List[str] = None + size_bytes: int = 0 + + def is_expired(self) -> bool: + """Check if cache entry is expired.""" + if self.ttl_seconds is None: + return False + return (datetime.now() - self.created_at).total_seconds() > self.ttl_seconds + + def touch(self): + """Update access time and count.""" + self.accessed_at = datetime.now() + self.access_count += 1 + + +@dataclass +class CacheStats: + """Cache performance statistics.""" + hits: int = 0 + misses: int = 0 + evictions: int = 0 + total_entries: int = 0 + total_size_bytes: int = 0 + hit_rate: float = 0.0 + + def update_hit_rate(self): + """Update hit rate calculation.""" + total_requests = self.hits + self.misses + self.hit_rate = self.hits / total_requests if total_requests > 0 else 0.0 + + +class CacheBackend(ABC): + """Abstract base class for cache backends.""" + + @abstractmethod + def get(self, key: str) -> Optional[CacheEntry]: + """Get cache entry by key.""" + pass + + @abstractmethod + def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None): + """Set cache entry.""" + pass + + @abstractmethod + def delete(self, key: str) -> bool: + """Delete cache entry.""" + pass + + @abstractmethod + def clear(self, tags: Optional[List[str]] = None): + """Clear cache entries.""" + pass + + @abstractmethod + def get_stats(self) -> CacheStats: + """Get cache statistics.""" + pass + + @abstractmethod + def cleanup_expired(self): + """Clean up expired entries.""" + pass + + +class InMemoryCache(CacheBackend): + """In-memory cache backend.""" + + def __init__(self, max_size: int = 1000, max_size_bytes: int = 100 * 1024 * 1024): + """Initialize in-memory cache. 
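+
+ Example (a minimal sketch with made-up limits and keys):
+ cache = InMemoryCache(max_size=100)
+ cache.set("subset:animal", {"classes": ["animal"]}, ttl_seconds=60)
+ entry = cache.get("subset:animal")
+ value = entry.value if entry else None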
+ + Args: + max_size: Maximum number of entries + max_size_bytes: Maximum total size in bytes + """ + self.max_size = max_size + self.max_size_bytes = max_size_bytes + self.entries: Dict[str, CacheEntry] = {} + self.stats = CacheStats() + self._lock = threading.RLock() + + def get(self, key: str) -> Optional[CacheEntry]: + """Get cache entry by key.""" + with self._lock: + entry = self.entries.get(key) + if entry is None: + self.stats.misses += 1 + self.stats.update_hit_rate() + return None + + if entry.is_expired(): + del self.entries[key] + self.stats.misses += 1 + self.stats.evictions += 1 + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= entry.size_bytes + self.stats.update_hit_rate() + return None + + entry.touch() + self.stats.hits += 1 + self.stats.update_hit_rate() + return entry + + def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None): + """Set cache entry.""" + with self._lock: + # Calculate size + try: + size_bytes = len(pickle.dumps(value)) + except Exception: + size_bytes = len(str(value).encode('utf-8')) + + # Create entry + now = datetime.now() + entry = CacheEntry( + key=key, + value=value, + created_at=now, + accessed_at=now, + access_count=1, + ttl_seconds=ttl_seconds, + tags=tags or [], + size_bytes=size_bytes + ) + + # Check if we need to evict + self._ensure_capacity(size_bytes) + + # Store entry + old_entry = self.entries.get(key) + if old_entry: + self.stats.total_size_bytes -= old_entry.size_bytes + else: + self.stats.total_entries += 1 + + self.entries[key] = entry + self.stats.total_size_bytes += size_bytes + + def delete(self, key: str) -> bool: + """Delete cache entry.""" + with self._lock: + entry = self.entries.pop(key, None) + if entry: + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= entry.size_bytes + self.stats.evictions += 1 + return True + return False + + def clear(self, tags: Optional[List[str]] = None): + """Clear cache entries.""" + with self._lock: + if tags is None: + # Clear all + self.stats.evictions += len(self.entries) + self.entries.clear() + self.stats.total_entries = 0 + self.stats.total_size_bytes = 0 + else: + # Clear by tags + to_delete = [] + for key, entry in self.entries.items(): + if any(tag in entry.tags for tag in tags): + to_delete.append(key) + + for key in to_delete: + self.delete(key) + + def get_stats(self) -> CacheStats: + """Get cache statistics.""" + with self._lock: + return CacheStats( + hits=self.stats.hits, + misses=self.stats.misses, + evictions=self.stats.evictions, + total_entries=self.stats.total_entries, + total_size_bytes=self.stats.total_size_bytes, + hit_rate=self.stats.hit_rate + ) + + def cleanup_expired(self): + """Clean up expired entries.""" + with self._lock: + to_delete = [] + for key, entry in self.entries.items(): + if entry.is_expired(): + to_delete.append(key) + + for key in to_delete: + self.delete(key) + + def _ensure_capacity(self, new_size_bytes: int): + """Ensure cache has capacity for new entry.""" + # Check size limit + if self.stats.total_size_bytes + new_size_bytes > self.max_size_bytes: + self._evict_by_size(new_size_bytes) + + # Check count limit + if len(self.entries) >= self.max_size: + self._evict_by_count() + + def _evict_by_size(self, needed_bytes: int): + """Evict entries to free up space.""" + # Sort by access time (LRU) + sorted_entries = sorted( + self.entries.items(), + key=lambda x: (x[1].accessed_at, x[1].access_count) + ) + + freed_bytes = 0 + for key, entry in sorted_entries: + if freed_bytes >= 
needed_bytes: + break + freed_bytes += entry.size_bytes + del self.entries[key] + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= entry.size_bytes + self.stats.evictions += 1 + + def _evict_by_count(self): + """Evict least recently used entry.""" + if not self.entries: + return + + # Find LRU entry + lru_key = min( + self.entries.keys(), + key=lambda k: (self.entries[k].accessed_at, self.entries[k].access_count) + ) + self.delete(lru_key) + + +class FileCache(CacheBackend): + """File-based cache backend.""" + + def __init__(self, cache_dir: str, max_files: int = 10000): + """Initialize file cache. + + Args: + cache_dir: Directory to store cache files + max_files: Maximum number of cache files + """ + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.max_files = max_files + self.stats = CacheStats() + self._lock = threading.RLock() + + # Load existing stats + self._load_stats() + + def get(self, key: str) -> Optional[CacheEntry]: + """Get cache entry by key.""" + with self._lock: + cache_file = self.cache_dir / f"{self._safe_key(key)}.cache" + if not cache_file.exists(): + self.stats.misses += 1 + self.stats.update_hit_rate() + return None + + try: + with open(cache_file, 'rb') as f: + entry = pickle.load(f) + + if entry.is_expired(): + cache_file.unlink() + self.stats.misses += 1 + self.stats.evictions += 1 + self.stats.total_entries -= 1 + self.stats.update_hit_rate() + return None + + entry.touch() + # Update file modification time + cache_file.touch() + + self.stats.hits += 1 + self.stats.update_hit_rate() + return entry + + except Exception as e: + logger.error(f"Error reading cache file {cache_file}: {e}") + cache_file.unlink(missing_ok=True) + self.stats.misses += 1 + self.stats.update_hit_rate() + return None + + def set(self, key: str, value: Any, ttl_seconds: Optional[int] = None, tags: List[str] = None): + """Set cache entry.""" + with self._lock: + cache_file = self.cache_dir / f"{self._safe_key(key)}.cache" + + # Create entry + now = datetime.now() + entry = CacheEntry( + key=key, + value=value, + created_at=now, + accessed_at=now, + access_count=1, + ttl_seconds=ttl_seconds, + tags=tags or [] + ) + + try: + # Ensure capacity + self._ensure_capacity() + + # Write to file + with open(cache_file, 'wb') as f: + pickle.dump(entry, f) + + entry.size_bytes = cache_file.stat().st_size + + if not cache_file.exists(): + self.stats.total_entries += 1 + + self.stats.total_size_bytes += entry.size_bytes + self._save_stats() + + except Exception as e: + logger.error(f"Error writing cache file {cache_file}: {e}") + + def delete(self, key: str) -> bool: + """Delete cache entry.""" + with self._lock: + cache_file = self.cache_dir / f"{self._safe_key(key)}.cache" + if cache_file.exists(): + size = cache_file.stat().st_size + cache_file.unlink() + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= size + self.stats.evictions += 1 + self._save_stats() + return True + return False + + def clear(self, tags: Optional[List[str]] = None): + """Clear cache entries.""" + with self._lock: + if tags is None: + # Clear all + for cache_file in self.cache_dir.glob("*.cache"): + cache_file.unlink() + self.stats.evictions += self.stats.total_entries + self.stats.total_entries = 0 + self.stats.total_size_bytes = 0 + else: + # Clear by tags + for cache_file in self.cache_dir.glob("*.cache"): + try: + with open(cache_file, 'rb') as f: + entry = pickle.load(f) + if any(tag in entry.tags for tag in tags): + size = cache_file.stat().st_size + 
cache_file.unlink() + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= size + self.stats.evictions += 1 + except Exception: + continue + + self._save_stats() + + def get_stats(self) -> CacheStats: + """Get cache statistics.""" + with self._lock: + return CacheStats( + hits=self.stats.hits, + misses=self.stats.misses, + evictions=self.stats.evictions, + total_entries=self.stats.total_entries, + total_size_bytes=self.stats.total_size_bytes, + hit_rate=self.stats.hit_rate + ) + + def cleanup_expired(self): + """Clean up expired entries.""" + with self._lock: + for cache_file in self.cache_dir.glob("*.cache"): + try: + with open(cache_file, 'rb') as f: + entry = pickle.load(f) + if entry.is_expired(): + size = cache_file.stat().st_size + cache_file.unlink() + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= size + self.stats.evictions += 1 + except Exception: + # Remove corrupted files + cache_file.unlink() + + self._save_stats() + + def _safe_key(self, key: str) -> str: + """Convert key to safe filename.""" + import hashlib + return hashlib.md5(key.encode()).hexdigest() + + def _ensure_capacity(self): + """Ensure cache has capacity for new entry.""" + cache_files = list(self.cache_dir.glob("*.cache")) + if len(cache_files) >= self.max_files: + # Remove oldest file + oldest_file = min(cache_files, key=lambda f: f.stat().st_mtime) + size = oldest_file.stat().st_size + oldest_file.unlink() + self.stats.total_entries -= 1 + self.stats.total_size_bytes -= size + self.stats.evictions += 1 + + def _load_stats(self): + """Load statistics from file.""" + stats_file = self.cache_dir / "stats.json" + if stats_file.exists(): + try: + with open(stats_file, 'r') as f: + data = json.load(f) + self.stats = CacheStats(**data) + except Exception: + pass + + def _save_stats(self): + """Save statistics to file.""" + stats_file = self.cache_dir / "stats.json" + try: + with open(stats_file, 'w') as f: + json.dump(asdict(self.stats), f, default=str) + except Exception: + pass + + +class CacheManager: + """Cache manager with multiple backends and intelligent caching strategies.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize cache manager. + + Args: + config: Cache configuration + """ + self.config = config + self.backends: Dict[str, CacheBackend] = {} + self.default_backend = config.get('default_backend', 'memory') + self.default_ttl = config.get('default_ttl_seconds', 3600) # 1 hour + + # Initialize backends + self._init_backends() + + # Start cleanup task + self.cleanup_interval = config.get('cleanup_interval_seconds', 300) # 5 minutes + self._start_cleanup_task() + + def _init_backends(self): + """Initialize cache backends.""" + backends_config = self.config.get('backends', {}) + + # Memory backend + if 'memory' in backends_config or self.default_backend == 'memory': + memory_config = backends_config.get('memory', {}) + self.backends['memory'] = InMemoryCache( + max_size=memory_config.get('max_size', 1000), + max_size_bytes=memory_config.get('max_size_bytes', 100 * 1024 * 1024) + ) + + # File backend + if 'file' in backends_config or self.default_backend == 'file': + file_config = backends_config.get('file', {}) + self.backends['file'] = FileCache( + cache_dir=file_config.get('cache_dir', './cache'), + max_files=file_config.get('max_files', 10000) + ) + + def get(self, key: str, backend: Optional[str] = None) -> Optional[Any]: + """Get value from cache. 
+ + Args: + key: Cache key + backend: Backend name (optional) + + Returns: + Cached value or None + """ + backend_name = backend or self.default_backend + cache_backend = self.backends.get(backend_name) + + if cache_backend is None: + logger.warning(f"Cache backend '{backend_name}' not found") + return None + + entry = cache_backend.get(key) + return entry.value if entry else None + + def set(self, + key: str, + value: Any, + ttl_seconds: Optional[int] = None, + tags: Optional[List[str]] = None, + backend: Optional[str] = None): + """Set value in cache. + + Args: + key: Cache key + value: Value to cache + ttl_seconds: Time to live in seconds + tags: Cache tags + backend: Backend name (optional) + """ + backend_name = backend or self.default_backend + cache_backend = self.backends.get(backend_name) + + if cache_backend is None: + logger.warning(f"Cache backend '{backend_name}' not found") + return + + ttl = ttl_seconds if ttl_seconds is not None else self.default_ttl + cache_backend.set(key, value, ttl, tags) + + def delete(self, key: str, backend: Optional[str] = None) -> bool: + """Delete value from cache. + + Args: + key: Cache key + backend: Backend name (optional) + + Returns: + True if deleted + """ + backend_name = backend or self.default_backend + cache_backend = self.backends.get(backend_name) + + if cache_backend is None: + return False + + return cache_backend.delete(key) + + def clear(self, tags: Optional[List[str]] = None, backend: Optional[str] = None): + """Clear cache entries. + + Args: + tags: Tags to clear (optional) + backend: Backend name (optional) + """ + if backend: + cache_backend = self.backends.get(backend) + if cache_backend: + cache_backend.clear(tags) + else: + # Clear all backends + for cache_backend in self.backends.values(): + cache_backend.clear(tags) + + def get_stats(self) -> Dict[str, CacheStats]: + """Get statistics for all backends. + + Returns: + Dictionary of backend name to statistics + """ + return {name: backend.get_stats() for name, backend in self.backends.items()} + + def cleanup_expired(self): + """Clean up expired entries in all backends.""" + for backend in self.backends.values(): + try: + backend.cleanup_expired() + except Exception as e: + logger.error(f"Error cleaning up cache backend: {e}") + + def _start_cleanup_task(self): + """Start periodic cleanup task.""" + def cleanup_worker(): + while True: + try: + time.sleep(self.cleanup_interval) + self.cleanup_expired() + except Exception as e: + logger.error(f"Cache cleanup error: {e}") + + cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) + cleanup_thread.start() + + +# Cache decorators and utilities + +def cache_result(cache_manager: CacheManager, + key_func: Optional[callable] = None, + ttl_seconds: Optional[int] = None, + tags: Optional[List[str]] = None, + backend: Optional[str] = None): + """Decorator to cache function results. 
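+
+ Example (illustrative; the decorated function is hypothetical):
+ @cache_result(cache_manager, ttl_seconds=300, tags=["ontology-subset"])
+ def build_subset(ontology_id):
+ ...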
+ + Args: + cache_manager: Cache manager instance + key_func: Function to generate cache key + ttl_seconds: Time to live + tags: Cache tags + backend: Backend name + """ + def decorator(func): + def wrapper(*args, **kwargs): + # Generate cache key + if key_func: + cache_key = key_func(*args, **kwargs) + else: + cache_key = f"{func.__name__}:{hash((args, tuple(sorted(kwargs.items()))))}" + + # Try to get from cache + cached_result = cache_manager.get(cache_key, backend) + if cached_result is not None: + return cached_result + + # Execute function + result = func(*args, **kwargs) + + # Cache result + cache_manager.set(cache_key, result, ttl_seconds, tags, backend) + + return result + + return wrapper + return decorator \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/cypher_executor.py b/trustgraph-flow/trustgraph/query/ontology/cypher_executor.py new file mode 100644 index 00000000..56e4c829 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/cypher_executor.py @@ -0,0 +1,610 @@ +""" +Cypher executor for multiple graph databases. +Executes Cypher queries against Neo4j, Memgraph, and FalkorDB. +""" + +import logging +import asyncio +from typing import Dict, Any, List, Optional, Union +from dataclasses import dataclass +from abc import ABC, abstractmethod + +from .cypher_generator import CypherQuery + +logger = logging.getLogger(__name__) + +# Try to import various database drivers +try: + from neo4j import GraphDatabase, Driver as Neo4jDriver + NEO4J_AVAILABLE = True +except ImportError: + NEO4J_AVAILABLE = False + Neo4jDriver = None + +try: + import redis + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + + +@dataclass +class CypherResult: + """Result from Cypher query execution.""" + records: List[Dict[str, Any]] + summary: Dict[str, Any] + execution_time: float + database_type: str + query_plan: Optional[Dict[str, Any]] = None + + +class CypherExecutorBase(ABC): + """Abstract base class for Cypher executors.""" + + @abstractmethod + async def execute(self, cypher_query: CypherQuery) -> CypherResult: + """Execute Cypher query.""" + pass + + @abstractmethod + async def close(self): + """Close database connection.""" + pass + + @abstractmethod + def is_connected(self) -> bool: + """Check if connected to database.""" + pass + + +class Neo4jExecutor(CypherExecutorBase): + """Cypher executor for Neo4j database.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize Neo4j executor. 
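+
+ Example (illustrative connection settings; the credentials are placeholders):
+ executor = Neo4jExecutor({
+ "uri": "bolt://localhost:7687",
+ "username": "neo4j",
+ "password": "secret",
+ "connection_pool_size": 10,
+ })
+ result = await executor.execute(cypher_query)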
+ + Args: + config: Neo4j configuration + """ + if not NEO4J_AVAILABLE: + raise RuntimeError("Neo4j driver not available") + + self.config = config + self.driver: Optional[Neo4jDriver] = None + self._connection_pool_size = config.get('connection_pool_size', 10) + + async def connect(self): + """Connect to Neo4j database.""" + try: + uri = self.config.get('uri', 'bolt://localhost:7687') + username = self.config.get('username') + password = self.config.get('password') + + auth = (username, password) if username and password else None + + # Create driver with connection pool + self.driver = GraphDatabase.driver( + uri, + auth=auth, + max_connection_pool_size=self._connection_pool_size, + connection_timeout=self.config.get('connection_timeout', 30), + max_retry_time=self.config.get('max_retry_time', 15) + ) + + # Verify connectivity + await asyncio.get_event_loop().run_in_executor( + None, self.driver.verify_connectivity + ) + + logger.info(f"Connected to Neo4j at {uri}") + + except Exception as e: + logger.error(f"Failed to connect to Neo4j: {e}") + raise + + async def execute(self, cypher_query: CypherQuery) -> CypherResult: + """Execute Cypher query against Neo4j. + + Args: + cypher_query: Cypher query to execute + + Returns: + Query results + """ + if not self.driver: + await self.connect() + + import time + start_time = time.time() + + try: + # Execute query in a session + records = await asyncio.get_event_loop().run_in_executor( + None, self._execute_sync, cypher_query + ) + + execution_time = time.time() - start_time + + return CypherResult( + records=records, + summary={'record_count': len(records)}, + execution_time=execution_time, + database_type='neo4j' + ) + + except Exception as e: + logger.error(f"Neo4j query execution error: {e}") + execution_time = time.time() - start_time + return CypherResult( + records=[], + summary={'error': str(e)}, + execution_time=execution_time, + database_type='neo4j' + ) + + def _execute_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]: + """Execute query synchronously in thread executor. + + Args: + cypher_query: Cypher query to execute + + Returns: + List of record dictionaries + """ + with self.driver.session() as session: + result = session.run(cypher_query.query, cypher_query.parameters) + records = [] + for record in result: + record_dict = {} + for key in record.keys(): + value = record[key] + record_dict[key] = self._format_neo4j_value(value) + records.append(record_dict) + return records + + def _format_neo4j_value(self, value): + """Format Neo4j value for JSON serialization. 
+ + Args: + value: Neo4j value + + Returns: + JSON-serializable value + """ + # Handle Neo4j node objects + if hasattr(value, 'labels') and hasattr(value, 'items'): + return { + 'labels': list(value.labels), + 'properties': dict(value.items()) + } + # Handle Neo4j relationship objects + elif hasattr(value, 'type') and hasattr(value, 'items'): + return { + 'type': value.type, + 'properties': dict(value.items()) + } + # Handle Neo4j path objects + elif hasattr(value, 'nodes') and hasattr(value, 'relationships'): + return { + 'nodes': [self._format_neo4j_value(n) for n in value.nodes], + 'relationships': [self._format_neo4j_value(r) for r in value.relationships] + } + else: + return value + + async def close(self): + """Close Neo4j connection.""" + if self.driver: + await asyncio.get_event_loop().run_in_executor( + None, self.driver.close + ) + self.driver = None + logger.info("Neo4j connection closed") + + def is_connected(self) -> bool: + """Check if connected to Neo4j.""" + return self.driver is not None + + +class MemgraphExecutor(CypherExecutorBase): + """Cypher executor for Memgraph database.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize Memgraph executor. + + Args: + config: Memgraph configuration + """ + if not NEO4J_AVAILABLE: # Memgraph uses Neo4j driver + raise RuntimeError("Neo4j driver required for Memgraph") + + self.config = config + self.driver: Optional[Neo4jDriver] = None + + async def connect(self): + """Connect to Memgraph database.""" + try: + uri = self.config.get('uri', 'bolt://localhost:7688') + username = self.config.get('username') + password = self.config.get('password') + + auth = (username, password) if username and password else None + + # Memgraph uses Neo4j driver but with different defaults + self.driver = GraphDatabase.driver( + uri, + auth=auth, + max_connection_pool_size=self.config.get('connection_pool_size', 5), + connection_timeout=self.config.get('connection_timeout', 10) + ) + + # Verify connectivity + await asyncio.get_event_loop().run_in_executor( + None, self.driver.verify_connectivity + ) + + logger.info(f"Connected to Memgraph at {uri}") + + except Exception as e: + logger.error(f"Failed to connect to Memgraph: {e}") + raise + + async def execute(self, cypher_query: CypherQuery) -> CypherResult: + """Execute Cypher query against Memgraph. + + Args: + cypher_query: Cypher query to execute + + Returns: + Query results + """ + if not self.driver: + await self.connect() + + import time + start_time = time.time() + + try: + # Execute query with Memgraph-specific optimizations + records = await asyncio.get_event_loop().run_in_executor( + None, self._execute_memgraph_sync, cypher_query + ) + + execution_time = time.time() - start_time + + return CypherResult( + records=records, + summary={ + 'record_count': len(records), + 'engine': 'memgraph' + }, + execution_time=execution_time, + database_type='memgraph' + ) + + except Exception as e: + logger.error(f"Memgraph query execution error: {e}") + execution_time = time.time() - start_time + return CypherResult( + records=[], + summary={'error': str(e)}, + execution_time=execution_time, + database_type='memgraph' + ) + + def _execute_memgraph_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]: + """Execute query synchronously for Memgraph. 
+ + Args: + cypher_query: Cypher query to execute + + Returns: + List of record dictionaries + """ + with self.driver.session() as session: + # Add Memgraph-specific query hints if available + query = cypher_query.query + if cypher_query.database_hints and cypher_query.database_hints.get('memory_limit'): + # Memgraph supports memory limits + query = f"// Memory limit: {cypher_query.database_hints['memory_limit']}\n{query}" + + result = session.run(query, cypher_query.parameters) + records = [] + for record in result: + record_dict = {} + for key in record.keys(): + record_dict[key] = record[key] + records.append(record_dict) + return records + + async def close(self): + """Close Memgraph connection.""" + if self.driver: + await asyncio.get_event_loop().run_in_executor( + None, self.driver.close + ) + self.driver = None + logger.info("Memgraph connection closed") + + def is_connected(self) -> bool: + """Check if connected to Memgraph.""" + return self.driver is not None + + +class FalkorDBExecutor(CypherExecutorBase): + """Cypher executor for FalkorDB (Redis-based graph database).""" + + def __init__(self, config: Dict[str, Any]): + """Initialize FalkorDB executor. + + Args: + config: FalkorDB configuration + """ + if not REDIS_AVAILABLE: + raise RuntimeError("Redis driver required for FalkorDB") + + self.config = config + self.redis_client: Optional[redis.Redis] = None + self.graph_name = config.get('graph_name', 'knowledge_graph') + + async def connect(self): + """Connect to FalkorDB (Redis).""" + try: + self.redis_client = redis.Redis( + host=self.config.get('host', 'localhost'), + port=self.config.get('port', 6379), + password=self.config.get('password'), + db=self.config.get('db', 0), + decode_responses=True, + socket_connect_timeout=self.config.get('connection_timeout', 10), + socket_timeout=self.config.get('socket_timeout', 10) + ) + + # Test connection + await asyncio.get_event_loop().run_in_executor( + None, self.redis_client.ping + ) + + logger.info(f"Connected to FalkorDB at {self.config.get('host', 'localhost')}") + + except Exception as e: + logger.error(f"Failed to connect to FalkorDB: {e}") + raise + + async def execute(self, cypher_query: CypherQuery) -> CypherResult: + """Execute Cypher query against FalkorDB. + + Args: + cypher_query: Cypher query to execute + + Returns: + Query results + """ + if not self.redis_client: + await self.connect() + + import time + start_time = time.time() + + try: + # Execute query using FalkorDB's GRAPH.QUERY command + records = await asyncio.get_event_loop().run_in_executor( + None, self._execute_falkordb_sync, cypher_query + ) + + execution_time = time.time() - start_time + + return CypherResult( + records=records, + summary={ + 'record_count': len(records), + 'engine': 'falkordb' + }, + execution_time=execution_time, + database_type='falkordb' + ) + + except Exception as e: + logger.error(f"FalkorDB query execution error: {e}") + execution_time = time.time() - start_time + return CypherResult( + records=[], + summary={'error': str(e)}, + execution_time=execution_time, + database_type='falkordb' + ) + + def _execute_falkordb_sync(self, cypher_query: CypherQuery) -> List[Dict[str, Any]]: + """Execute query synchronously for FalkorDB. 
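+
+        Parameters are inlined by plain string substitution (e.g. $name
+        becomes "lion") before the query is sent via GRAPH.QUERY, so values
+        are not escaped the way driver-side parameters would be; parameter
+        values should be treated as trusted input here. Illustrative call:
+
+            GRAPH.QUERY knowledge_graph "MATCH (n:animal) RETURN n.name"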
+ + Args: + cypher_query: Cypher query to execute + + Returns: + List of record dictionaries + """ + # Substitute parameters in query (FalkorDB parameter handling) + query = cypher_query.query + for param, value in cypher_query.parameters.items(): + if isinstance(value, str): + query = query.replace(f'${param}', f'"{value}"') + else: + query = query.replace(f'${param}', str(value)) + + # Execute using FalkorDB GRAPH.QUERY command + result = self.redis_client.execute_command( + 'GRAPH.QUERY', self.graph_name, query + ) + + # Parse FalkorDB result format + records = [] + if result and len(result) > 1: + # FalkorDB returns [header, data rows, statistics] + headers = result[0] if result[0] else [] + data_rows = result[1] if len(result) > 1 else [] + + for row in data_rows: + record = {} + for i, header in enumerate(headers): + if i < len(row): + record[header] = self._format_falkordb_value(row[i]) + records.append(record) + + return records + + def _format_falkordb_value(self, value): + """Format FalkorDB value for JSON serialization. + + Args: + value: FalkorDB value + + Returns: + JSON-serializable value + """ + # FalkorDB returns values in specific formats + if isinstance(value, list) and len(value) == 3: + # Check if it's a node/relationship representation + if value[0] == 1: # Node + return { + 'type': 'node', + 'labels': value[1], + 'properties': value[2] + } + elif value[0] == 2: # Relationship + return { + 'type': 'relationship', + 'rel_type': value[1], + 'properties': value[2] + } + + return value + + async def close(self): + """Close FalkorDB connection.""" + if self.redis_client: + await asyncio.get_event_loop().run_in_executor( + None, self.redis_client.close + ) + self.redis_client = None + logger.info("FalkorDB connection closed") + + def is_connected(self) -> bool: + """Check if connected to FalkorDB.""" + return self.redis_client is not None + + +class CypherExecutor: + """Multi-database Cypher executor with automatic routing.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize multi-database executor. + + Args: + config: Configuration for all database types + """ + self.config = config + self.executors: Dict[str, CypherExecutorBase] = {} + + # Initialize available executors + self._initialize_executors() + + def _initialize_executors(self): + """Initialize database executors based on configuration.""" + # Neo4j executor + if 'neo4j' in self.config and NEO4J_AVAILABLE: + try: + self.executors['neo4j'] = Neo4jExecutor(self.config['neo4j']) + logger.info("Neo4j executor initialized") + except Exception as e: + logger.error(f"Failed to initialize Neo4j executor: {e}") + + # Memgraph executor + if 'memgraph' in self.config and NEO4J_AVAILABLE: + try: + self.executors['memgraph'] = MemgraphExecutor(self.config['memgraph']) + logger.info("Memgraph executor initialized") + except Exception as e: + logger.error(f"Failed to initialize Memgraph executor: {e}") + + # FalkorDB executor + if 'falkordb' in self.config and REDIS_AVAILABLE: + try: + self.executors['falkordb'] = FalkorDBExecutor(self.config['falkordb']) + logger.info("FalkorDB executor initialized") + except Exception as e: + logger.error(f"Failed to initialize FalkorDB executor: {e}") + + if not self.executors: + raise RuntimeError("No database executors could be initialized") + + async def execute_cypher(self, cypher_query: CypherQuery, + database_type: str) -> CypherResult: + """Execute Cypher query on specified database. 
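+
+        Example (sketch; the connection details are placeholders and
+        `cypher` is a CypherQuery produced by the generator):
+
+            executor = CypherExecutor({
+                "neo4j": {"uri": "bolt://localhost:7687",
+                          "username": "neo4j", "password": "secret"},
+            })
+            result = await executor.execute_cypher(cypher, "neo4j")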
+ + Args: + cypher_query: Cypher query to execute + database_type: Target database type + + Returns: + Query results + """ + if database_type not in self.executors: + raise ValueError(f"Database type {database_type} not available. " + f"Available: {list(self.executors.keys())}") + + executor = self.executors[database_type] + + # Ensure connection + if not executor.is_connected(): + await executor.connect() + + # Execute query + return await executor.execute(cypher_query) + + async def execute_on_all(self, cypher_query: CypherQuery) -> Dict[str, CypherResult]: + """Execute query on all available databases. + + Args: + cypher_query: Cypher query to execute + + Returns: + Results from all databases + """ + results = {} + tasks = [] + + for db_type, executor in self.executors.items(): + task = asyncio.create_task( + self.execute_cypher(cypher_query, db_type), + name=f"cypher_query_{db_type}" + ) + tasks.append((db_type, task)) + + # Wait for all tasks to complete + for db_type, task in tasks: + try: + results[db_type] = await task + except Exception as e: + logger.error(f"Query failed on {db_type}: {e}") + results[db_type] = CypherResult( + records=[], + summary={'error': str(e)}, + execution_time=0.0, + database_type=db_type + ) + + return results + + def get_available_databases(self) -> List[str]: + """Get list of available database types. + + Returns: + List of available database type names + """ + return list(self.executors.keys()) + + async def close_all(self): + """Close all database connections.""" + for executor in self.executors.values(): + await executor.close() + logger.info("All Cypher executor connections closed") \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/cypher_generator.py b/trustgraph-flow/trustgraph/query/ontology/cypher_generator.py new file mode 100644 index 00000000..8c43e964 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/cypher_generator.py @@ -0,0 +1,628 @@ +""" +Cypher query generator for ontology-sensitive queries. +Converts natural language questions to Cypher queries for graph databases. +""" + +import logging +from typing import Dict, Any, List, Optional +from dataclasses import dataclass + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset + +logger = logging.getLogger(__name__) + + +@dataclass +class CypherQuery: + """Generated Cypher query with metadata.""" + query: str + parameters: Dict[str, Any] + variables: List[str] + explanation: str + complexity_score: float + database_hints: Dict[str, Any] = None # Database-specific optimization hints + + +class CypherGenerator: + """Generates Cypher queries from natural language questions using LLM assistance.""" + + def __init__(self, prompt_service=None): + """Initialize Cypher generator. 
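+
+        Example (sketch; `prompt_service` stands in for whatever LLM client
+        the flow wires in):
+
+            generator = CypherGenerator(prompt_service=prompt_service)
+            cypher = await generator.generate_cypher(
+                question_components, ontology_subset, database_type="neo4j")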
+ + Args: + prompt_service: Service for LLM-based query generation + """ + self.prompt_service = prompt_service + + # Cypher query templates for common patterns + self.templates = { + 'simple_node_query': """ +MATCH (n:{node_label}) +RETURN n.name AS name, n.{property} AS {property} +LIMIT {limit}""", + + 'relationship_query': """ +MATCH (a:{source_label})-[r:{relationship}]->(b:{target_label}) +WHERE a.name = $source_name +RETURN b.name AS name, r.{rel_property} AS property""", + + 'path_query': """ +MATCH path = (start:{start_label})-[*1..{max_depth}]->(end:{end_label}) +WHERE start.name = $start_name +RETURN path, length(path) AS path_length +ORDER BY path_length""", + + 'count_query': """ +MATCH (n:{node_label}) +{where_clause} +RETURN count(n) AS count""", + + 'aggregation_query': """ +MATCH (n:{node_label}) +{where_clause} +RETURN + count(n) AS count, + avg(n.{numeric_property}) AS average, + sum(n.{numeric_property}) AS total""", + + 'boolean_query': """ +MATCH (a:{source_label})-[:{relationship}]->(b:{target_label}) +WHERE a.name = $source_name AND b.name = $target_name +RETURN count(*) > 0 AS exists""", + + 'hierarchy_query': """ +MATCH (child:{child_label})-[:SUBCLASS_OF*]->(parent:{parent_label}) +WHERE parent.name = $parent_name +RETURN child.name AS child_name, parent.name AS parent_name""", + + 'property_filter_query': """ +MATCH (n:{node_label}) +WHERE n.{property} {operator} ${property}_value +RETURN n.name AS name, n.{property} AS {property} +ORDER BY n.{property}""" + } + + async def generate_cypher(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + database_type: str = "neo4j") -> CypherQuery: + """Generate Cypher query for a question. + + Args: + question_components: Analyzed question components + ontology_subset: Relevant ontology subset + database_type: Target database (neo4j, memgraph, falkordb) + + Returns: + Generated Cypher query + """ + # Try template-based generation first + template_query = self._try_template_generation( + question_components, ontology_subset, database_type + ) + if template_query: + logger.debug("Generated Cypher using template") + return template_query + + # Fall back to LLM-based generation + if self.prompt_service: + llm_query = await self._generate_with_llm( + question_components, ontology_subset, database_type + ) + if llm_query: + logger.debug("Generated Cypher using LLM") + return llm_query + + # Final fallback to simple pattern + logger.warning("Falling back to simple Cypher pattern") + return self._generate_fallback_query(question_components, ontology_subset) + + def _try_template_generation(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + database_type: str) -> Optional[CypherQuery]: + """Try to generate query using templates. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + database_type: Target database type + + Returns: + Generated query or None if no template matches + """ + # Simple node query (What are the animals?) 
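+        # Illustratively, an entity matched to a hypothetical `animal`
+        # class renders the simple_node_query template as a MATCH on
+        # (n:animal) returning n.name with a LIMIT of 100.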
+ if (question_components.question_type == QuestionType.RETRIEVAL and + len(question_components.entities) == 1): + + node_label = self._find_matching_node_label( + question_components.entities[0], ontology_subset + ) + if node_label: + query = self.templates['simple_node_query'].format( + node_label=node_label, + property='name', + limit=100 + ) + return CypherQuery( + query=query, + parameters={}, + variables=['name'], + explanation=f"Retrieve all nodes of type {node_label}", + complexity_score=0.2, + database_hints=self._get_database_hints(database_type, 'simple') + ) + + # Count query (How many animals are there?) + if (question_components.question_type == QuestionType.AGGREGATION and + 'count' in question_components.aggregations): + + node_label = self._find_matching_node_label( + question_components.entities[0] if question_components.entities else 'Entity', + ontology_subset + ) + if node_label: + where_clause = self._build_where_clause(question_components) + query = self.templates['count_query'].format( + node_label=node_label, + where_clause=where_clause + ) + return CypherQuery( + query=query, + parameters=self._extract_parameters(question_components), + variables=['count'], + explanation=f"Count nodes of type {node_label}", + complexity_score=0.3, + database_hints=self._get_database_hints(database_type, 'aggregation') + ) + + # Relationship query (Which documents were authored by John Smith?) + if (question_components.question_type == QuestionType.RETRIEVAL and + len(question_components.entities) >= 2): + + source_label = self._find_matching_node_label( + question_components.entities[1], ontology_subset + ) + target_label = self._find_matching_node_label( + question_components.entities[0], ontology_subset + ) + relationship = self._find_matching_relationship( + question_components, ontology_subset + ) + + if source_label and target_label and relationship: + query = self.templates['relationship_query'].format( + source_label=source_label, + target_label=target_label, + relationship=relationship, + rel_property='name' + ) + return CypherQuery( + query=query, + parameters={'source_name': question_components.entities[1]}, + variables=['name'], + explanation=f"Find {target_label} related to {source_label} via {relationship}", + complexity_score=0.4, + database_hints=self._get_database_hints(database_type, 'relationship') + ) + + # Boolean query (Is X related to Y?) 
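+        # Illustratively, "Is lion related to mammal?" with labels matched
+        # for both entities renders the boolean_query template as a MATCH
+        # between (a) and (b) filtered on $source_name and $target_name,
+        # returning count(*) > 0 AS exists.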
+ if question_components.question_type == QuestionType.BOOLEAN: + if len(question_components.entities) >= 2: + source_label = self._find_matching_node_label( + question_components.entities[0], ontology_subset + ) + target_label = self._find_matching_node_label( + question_components.entities[1], ontology_subset + ) + relationship = self._find_matching_relationship( + question_components, ontology_subset + ) + + if source_label and target_label and relationship: + query = self.templates['boolean_query'].format( + source_label=source_label, + target_label=target_label, + relationship=relationship + ) + return CypherQuery( + query=query, + parameters={ + 'source_name': question_components.entities[0], + 'target_name': question_components.entities[1] + }, + variables=['exists'], + explanation="Boolean check for relationship existence", + complexity_score=0.3, + database_hints=self._get_database_hints(database_type, 'boolean') + ) + + return None + + async def _generate_with_llm(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + database_type: str) -> Optional[CypherQuery]: + """Generate Cypher using LLM. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + database_type: Target database type + + Returns: + Generated query or None if failed + """ + try: + prompt = self._build_cypher_prompt( + question_components, ontology_subset, database_type + ) + response = await self.prompt_service.generate_cypher(prompt=prompt) + + if response and isinstance(response, dict): + query = response.get('query', '').strip() + if query.upper().startswith(('MATCH', 'CREATE', 'MERGE', 'DELETE', 'RETURN')): + return CypherQuery( + query=query, + parameters=response.get('parameters', {}), + variables=self._extract_variables(query), + explanation=response.get('explanation', 'Generated by LLM'), + complexity_score=self._calculate_complexity(query), + database_hints=self._get_database_hints(database_type, 'complex') + ) + + except Exception as e: + logger.error(f"LLM Cypher generation failed: {e}") + + return None + + def _build_cypher_prompt(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + database_type: str) -> str: + """Build prompt for LLM Cypher generation. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + database_type: Target database type + + Returns: + Formatted prompt string + """ + # Format ontology elements as node labels and relationships + node_labels = self._format_node_labels(ontology_subset.classes) + relationships = self._format_relationships( + ontology_subset.object_properties, + ontology_subset.datatype_properties + ) + + prompt = f"""Generate a Cypher query for the following question using the provided ontology. 
+ +QUESTION: {question_components.original_question} + +TARGET DATABASE: {database_type} + +AVAILABLE NODE LABELS (from classes): +{node_labels} + +AVAILABLE RELATIONSHIP TYPES (from properties): +{relationships} + +RULES: +- Use MATCH patterns for graph traversal +- Include WHERE clauses for filters +- Use aggregation functions when needed (COUNT, SUM, AVG) +- Optimize for {database_type} performance +- Consider index hints for large datasets +- Use parameters for values (e.g., $name) + +QUERY TYPE HINTS: +- Question type: {question_components.question_type.value} +- Expected answer: {question_components.expected_answer_type} +- Entities mentioned: {', '.join(question_components.entities)} +- Aggregations: {', '.join(question_components.aggregations)} + +DATABASE-SPECIFIC OPTIMIZATIONS: +{self._get_database_specific_hints(database_type)} + +Generate a complete Cypher query with parameters:""" + + return prompt + + def _generate_fallback_query(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> CypherQuery: + """Generate simple fallback query. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Basic Cypher query + """ + # Very basic MATCH query + first_class = list(ontology_subset.classes.keys())[0] if ontology_subset.classes else 'Entity' + + query = f"""MATCH (n:{first_class}) +WHERE n.name CONTAINS $keyword +RETURN n.name AS name, labels(n) AS types +LIMIT 10""" + + return CypherQuery( + query=query, + parameters={'keyword': question_components.keywords[0] if question_components.keywords else 'entity'}, + variables=['name', 'types'], + explanation="Fallback query for basic pattern matching", + complexity_score=0.1 + ) + + def _find_matching_node_label(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]: + """Find matching node label in ontology subset. + + Args: + entity: Entity string to match + ontology_subset: Ontology subset + + Returns: + Matching node label or None + """ + entity_lower = entity.lower() + + # Direct match + for class_id in ontology_subset.classes: + if class_id.lower() == entity_lower: + return class_id + + # Label match + for class_id, class_def in ontology_subset.classes.items(): + labels = class_def.get('labels', []) + for label in labels: + if isinstance(label, dict): + label_value = label.get('value', '').lower() + if label_value == entity_lower: + return class_id + + # Partial match + for class_id in ontology_subset.classes: + if entity_lower in class_id.lower() or class_id.lower() in entity_lower: + return class_id + + return None + + def _find_matching_relationship(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Optional[str]: + """Find matching relationship type. 
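+
+        Matched object-property identifiers are upper-cased with hyphens
+        replaced by underscores (a hypothetical `has-parent` property
+        becomes the relationship type HAS_PARENT); if nothing matches, the
+        generic RELATED_TO type is returned.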
+ + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Matching relationship type or None + """ + # Look for relationship keywords + for keyword in question_components.keywords: + keyword_lower = keyword.lower() + + # Check object properties + for prop_id in ontology_subset.object_properties: + if keyword_lower in prop_id.lower() or prop_id.lower() in keyword_lower: + return prop_id.upper().replace('-', '_') + + # Common relationship mappings + relationship_mappings = { + 'author': 'AUTHORED_BY', + 'created': 'CREATED_BY', + 'owns': 'OWNS', + 'has': 'HAS', + 'contains': 'CONTAINS', + 'parent': 'PARENT_OF', + 'child': 'CHILD_OF', + 'related': 'RELATED_TO' + } + + for keyword in question_components.keywords: + if keyword.lower() in relationship_mappings: + return relationship_mappings[keyword.lower()] + + # Default relationship + return 'RELATED_TO' + + def _build_where_clause(self, question_components: QuestionComponents) -> str: + """Build WHERE clause for Cypher query. + + Args: + question_components: Question analysis + + Returns: + WHERE clause string + """ + conditions = [] + + for constraint in question_components.constraints: + if 'greater than' in constraint.lower(): + import re + numbers = re.findall(r'\d+', constraint) + if numbers: + conditions.append(f"n.value > {numbers[0]}") + elif 'less than' in constraint.lower(): + numbers = re.findall(r'\d+', constraint) + if numbers: + conditions.append(f"n.value < {numbers[0]}") + + if conditions: + return f"WHERE {' AND '.join(conditions)}" + return "" + + def _extract_parameters(self, question_components: QuestionComponents) -> Dict[str, Any]: + """Extract parameters from question components. + + Args: + question_components: Question analysis + + Returns: + Parameters dictionary + """ + parameters = {} + + # Extract numeric values + import re + for constraint in question_components.constraints: + numbers = re.findall(r'\d+', constraint) + for i, number in enumerate(numbers): + parameters[f'value_{i}'] = int(number) + + return parameters + + def _format_node_labels(self, classes: Dict[str, Any]) -> str: + """Format classes as node labels for prompt. + + Args: + classes: Classes dictionary + + Returns: + Formatted node labels string + """ + if not classes: + return "None" + + lines = [] + for class_id, definition in classes.items(): + comment = definition.get('comment', '') + lines.append(f"- :{class_id} - {comment}") + + return '\n'.join(lines) + + def _format_relationships(self, + object_props: Dict[str, Any], + datatype_props: Dict[str, Any]) -> str: + """Format properties as relationships for prompt. + + Args: + object_props: Object properties + datatype_props: Datatype properties + + Returns: + Formatted relationships string + """ + lines = [] + + for prop_id, definition in object_props.items(): + domain = definition.get('domain', 'Any') + range_val = definition.get('range', 'Any') + comment = definition.get('comment', '') + rel_type = prop_id.upper().replace('-', '_') + lines.append(f"- :{rel_type} ({domain} -> {range_val}) - {comment}") + + return '\n'.join(lines) if lines else "None" + + def _extract_variables(self, query: str) -> List[str]: + """Extract variables from Cypher query. 
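+
+        This is a lightweight regex heuristic over the RETURN clause:
+        aliased expressions contribute their alias, but bare identifiers
+        (such as the node variable in `n.name AS name`) may also appear,
+        so callers should treat the result as approximate.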
+ + Args: + query: Cypher query string + + Returns: + List of variable names + """ + import re + # Extract RETURN clause variables + return_match = re.search(r'RETURN\s+(.+?)(?:ORDER|LIMIT|$)', query, re.IGNORECASE | re.DOTALL) + if return_match: + return_clause = return_match.group(1) + variables = re.findall(r'(\w+)(?:\s+AS\s+(\w+))?', return_clause) + return [var[1] if var[1] else var[0] for var in variables] + return [] + + def _calculate_complexity(self, query: str) -> float: + """Calculate complexity score for Cypher query. + + Args: + query: Cypher query string + + Returns: + Complexity score (0.0 to 1.0) + """ + complexity = 0.0 + query_upper = query.upper() + + # Count different Cypher features + if 'JOIN' in query_upper or 'UNION' in query_upper: + complexity += 0.3 + if 'WHERE' in query_upper: + complexity += 0.2 + if 'OPTIONAL' in query_upper: + complexity += 0.1 + if 'ORDER BY' in query_upper: + complexity += 0.1 + if '*' in query: # Variable length paths + complexity += 0.2 + if any(agg in query_upper for agg in ['COUNT', 'SUM', 'AVG', 'MAX', 'MIN']): + complexity += 0.2 + + # Count path length + path_matches = re.findall(r'\[.*?\*(\d+)\.\.(\d+).*?\]', query) + for start, end in path_matches: + complexity += (int(end) - int(start)) * 0.05 + + return min(complexity, 1.0) + + def _get_database_hints(self, database_type: str, query_category: str) -> Dict[str, Any]: + """Get database-specific optimization hints. + + Args: + database_type: Target database + query_category: Category of query + + Returns: + Optimization hints + """ + hints = {} + + if database_type == "neo4j": + hints.update({ + 'use_index': True, + 'explain_plan': 'EXPLAIN', + 'profile_query': 'PROFILE' + }) + elif database_type == "memgraph": + hints.update({ + 'use_index': True, + 'explain_plan': 'EXPLAIN', + 'memory_limit': '1GB' + }) + elif database_type == "falkordb": + hints.update({ + 'use_index': False, # Redis-based, different indexing + 'cache_result': True + }) + + return hints + + def _get_database_specific_hints(self, database_type: str) -> str: + """Get database-specific optimization hints as text. + + Args: + database_type: Target database + + Returns: + Hints as formatted string + """ + if database_type == "neo4j": + return """- Use USING INDEX hints for large datasets +- Consider PROFILE for query optimization +- Prefer MERGE over CREATE when appropriate""" + elif database_type == "memgraph": + return """- Leverage in-memory processing advantages +- Use streaming for large result sets +- Consider query parallelization""" + elif database_type == "falkordb": + return """- Optimize for Redis memory constraints +- Use simple patterns for best performance +- Leverage Redis data structures when possible""" + else: + return "- Use standard Cypher optimization patterns" \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/error_handling.py b/trustgraph-flow/trustgraph/query/ontology/error_handling.py new file mode 100644 index 00000000..cc047787 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/error_handling.py @@ -0,0 +1,557 @@ +""" +Error handling and recovery mechanisms for OntoRAG. +Provides comprehensive error handling, retry logic, and graceful degradation. 
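+
+Typical usage (sketch; component and operation names are illustrative):
+
+    strategy = ErrorRecoveryStrategy()
+    reporter = ErrorReporter()
+
+    @with_error_handling(ErrorCategory.QUERY_EXECUTION,
+                         component="cypher-executor", operation="execute")
+    async def run_query(query):
+        ...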
+""" + +import logging +import time +import asyncio +from typing import Dict, Any, List, Optional, Callable, Union, Type +from dataclasses import dataclass +from enum import Enum +from functools import wraps +import traceback + +logger = logging.getLogger(__name__) + + +class ErrorSeverity(Enum): + """Error severity levels.""" + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + + +class ErrorCategory(Enum): + """Error categories for better handling.""" + ONTOLOGY_LOADING = "ontology_loading" + QUESTION_ANALYSIS = "question_analysis" + QUERY_GENERATION = "query_generation" + QUERY_EXECUTION = "query_execution" + ANSWER_GENERATION = "answer_generation" + BACKEND_CONNECTION = "backend_connection" + CACHE_ERROR = "cache_error" + VALIDATION_ERROR = "validation_error" + TIMEOUT_ERROR = "timeout_error" + AUTHENTICATION_ERROR = "authentication_error" + + +@dataclass +class ErrorContext: + """Context information for an error.""" + category: ErrorCategory + severity: ErrorSeverity + component: str + operation: str + user_message: Optional[str] = None + technical_details: Optional[str] = None + suggestion: Optional[str] = None + retry_count: int = 0 + max_retries: int = 3 + metadata: Dict[str, Any] = None + + +class OntoRAGError(Exception): + """Base exception for OntoRAG system.""" + + def __init__(self, + message: str, + context: Optional[ErrorContext] = None, + cause: Optional[Exception] = None): + """Initialize OntoRAG error. + + Args: + message: Error message + context: Error context + cause: Original exception that caused this error + """ + super().__init__(message) + self.message = message + self.context = context or ErrorContext( + category=ErrorCategory.VALIDATION_ERROR, + severity=ErrorSeverity.MEDIUM, + component="unknown", + operation="unknown" + ) + self.cause = cause + self.timestamp = time.time() + + +class OntologyLoadingError(OntoRAGError): + """Error loading ontology.""" + pass + + +class QuestionAnalysisError(OntoRAGError): + """Error analyzing question.""" + pass + + +class QueryGenerationError(OntoRAGError): + """Error generating query.""" + pass + + +class QueryExecutionError(OntoRAGError): + """Error executing query.""" + pass + + +class AnswerGenerationError(OntoRAGError): + """Error generating answer.""" + pass + + +class BackendConnectionError(OntoRAGError): + """Error connecting to backend.""" + pass + + +class TimeoutError(OntoRAGError): + """Operation timeout error.""" + pass + + +@dataclass +class RetryConfig: + """Configuration for retry logic.""" + max_retries: int = 3 + base_delay: float = 1.0 + max_delay: float = 60.0 + exponential_backoff: bool = True + jitter: bool = True + retry_on_exceptions: List[Type[Exception]] = None + + +class ErrorRecoveryStrategy: + """Strategy for handling and recovering from errors.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize error recovery strategy. 
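+
+        Recognised config keys used later in this module include
+        `circuit_breaker_threshold` (default 10 errors) and
+        `circuit_breaker_window` (default 300 seconds).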
+ + Args: + config: Recovery configuration + """ + self.config = config or {} + self.retry_configs = self._build_retry_configs() + self.fallback_strategies = self._build_fallback_strategies() + self.error_counters: Dict[str, int] = {} + self.circuit_breakers: Dict[str, Dict[str, Any]] = {} + + def _build_retry_configs(self) -> Dict[ErrorCategory, RetryConfig]: + """Build retry configurations for different error categories.""" + return { + ErrorCategory.BACKEND_CONNECTION: RetryConfig( + max_retries=5, + base_delay=2.0, + retry_on_exceptions=[BackendConnectionError, ConnectionError, TimeoutError] + ), + ErrorCategory.QUERY_EXECUTION: RetryConfig( + max_retries=3, + base_delay=1.0, + retry_on_exceptions=[QueryExecutionError, TimeoutError] + ), + ErrorCategory.ONTOLOGY_LOADING: RetryConfig( + max_retries=2, + base_delay=0.5, + retry_on_exceptions=[OntologyLoadingError, IOError] + ), + ErrorCategory.QUESTION_ANALYSIS: RetryConfig( + max_retries=2, + base_delay=1.0, + retry_on_exceptions=[QuestionAnalysisError, TimeoutError] + ), + ErrorCategory.ANSWER_GENERATION: RetryConfig( + max_retries=2, + base_delay=1.0, + retry_on_exceptions=[AnswerGenerationError, TimeoutError] + ) + } + + def _build_fallback_strategies(self) -> Dict[ErrorCategory, Callable]: + """Build fallback strategies for different error categories.""" + return { + ErrorCategory.QUESTION_ANALYSIS: self._fallback_question_analysis, + ErrorCategory.QUERY_GENERATION: self._fallback_query_generation, + ErrorCategory.QUERY_EXECUTION: self._fallback_query_execution, + ErrorCategory.ANSWER_GENERATION: self._fallback_answer_generation, + ErrorCategory.BACKEND_CONNECTION: self._fallback_backend_connection + } + + async def handle_error(self, + error: Exception, + context: ErrorContext, + operation: Callable, + *args, + **kwargs) -> Any: + """Handle error with recovery strategies. 
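+
+        The strategy is: count the error, short-circuit to a fallback if
+        the category's circuit breaker is open, otherwise retry with
+        exponential backoff while the category's RetryConfig allows it,
+        and fall back only when retries are exhausted or not applicable.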
+ + Args: + error: The exception that occurred + context: Error context + operation: Function to retry + *args: Operation arguments + **kwargs: Operation keyword arguments + + Returns: + Result of successful operation or fallback + """ + logger.error(f"Handling error in {context.component}.{context.operation}: {error}") + + # Update error counters + error_key = f"{context.category.value}:{context.component}" + self.error_counters[error_key] = self.error_counters.get(error_key, 0) + 1 + + # Check circuit breaker + if self._is_circuit_open(error_key): + return await self._execute_fallback(context, *args, **kwargs) + + # Try retry if configured + retry_config = self.retry_configs.get(context.category) + if retry_config and context.retry_count < retry_config.max_retries: + if any(isinstance(error, exc_type) for exc_type in retry_config.retry_on_exceptions or []): + return await self._retry_operation( + operation, context, retry_config, *args, **kwargs + ) + + # Execute fallback strategy + return await self._execute_fallback(context, *args, **kwargs) + + async def _retry_operation(self, + operation: Callable, + context: ErrorContext, + retry_config: RetryConfig, + *args, + **kwargs) -> Any: + """Retry operation with backoff.""" + context.retry_count += 1 + + # Calculate delay + delay = retry_config.base_delay + if retry_config.exponential_backoff: + delay *= (2 ** (context.retry_count - 1)) + delay = min(delay, retry_config.max_delay) + + # Add jitter + if retry_config.jitter: + import random + delay *= (0.5 + random.random()) + + logger.info(f"Retrying {context.component}.{context.operation} " + f"(attempt {context.retry_count}) after {delay:.2f}s") + + await asyncio.sleep(delay) + + try: + if asyncio.iscoroutinefunction(operation): + return await operation(*args, **kwargs) + else: + return operation(*args, **kwargs) + except Exception as e: + return await self.handle_error(e, context, operation, *args, **kwargs) + + async def _execute_fallback(self, + context: ErrorContext, + *args, + **kwargs) -> Any: + """Execute fallback strategy.""" + fallback_func = self.fallback_strategies.get(context.category) + if fallback_func: + logger.info(f"Executing fallback for {context.category.value}") + try: + if asyncio.iscoroutinefunction(fallback_func): + return await fallback_func(context, *args, **kwargs) + else: + return fallback_func(context, *args, **kwargs) + except Exception as e: + logger.error(f"Fallback strategy failed: {e}") + + # Default fallback + return self._default_fallback(context) + + def _is_circuit_open(self, error_key: str) -> bool: + """Check if circuit breaker is open.""" + circuit = self.circuit_breakers.get(error_key, {}) + error_count = self.error_counters.get(error_key, 0) + error_threshold = self.config.get('circuit_breaker_threshold', 10) + window_seconds = self.config.get('circuit_breaker_window', 300) # 5 minutes + + current_time = time.time() + window_start = circuit.get('window_start', current_time) + + # Reset window if expired + if current_time - window_start > window_seconds: + self.circuit_breakers[error_key] = {'window_start': current_time} + self.error_counters[error_key] = 0 + return False + + return error_count >= error_threshold + + def _default_fallback(self, context: ErrorContext) -> Any: + """Default fallback response.""" + if context.category == ErrorCategory.ANSWER_GENERATION: + return "I'm sorry, I encountered an error while processing your question. Please try again." 
+ elif context.category == ErrorCategory.QUERY_EXECUTION: + return {"error": "Query execution failed", "results": []} + else: + return None + + # Fallback strategy implementations + + async def _fallback_question_analysis(self, context: ErrorContext, question: str, **kwargs): + """Fallback for question analysis.""" + from .question_analyzer import QuestionComponents, QuestionType + + # Simple keyword-based analysis + question_lower = question.lower() + + # Determine question type + if any(word in question_lower for word in ['how many', 'count', 'number']): + question_type = QuestionType.AGGREGATION + elif question_lower.startswith(('is', 'are', 'does', 'can')): + question_type = QuestionType.BOOLEAN + elif any(word in question_lower for word in ['what', 'which', 'who', 'where']): + question_type = QuestionType.RETRIEVAL + else: + question_type = QuestionType.FACTUAL + + # Extract simple entities (nouns) + import re + words = re.findall(r'\b[a-zA-Z]+\b', question) + entities = [word for word in words if len(word) > 3 and word.lower() not in + {'what', 'which', 'where', 'when', 'who', 'how', 'does', 'are', 'the'}] + + return QuestionComponents( + original_question=question, + normalized_question=question.lower(), + question_type=question_type, + entities=entities[:3], # Limit to 3 entities + keywords=words[:5], # Limit to 5 keywords + relationships=[], + constraints=[], + aggregations=['count'] if question_type == QuestionType.AGGREGATION else [], + expected_answer_type='text' + ) + + async def _fallback_query_generation(self, context: ErrorContext, **kwargs): + """Fallback for query generation.""" + # Generate simple query based on available information + if 'sparql' in context.metadata.get('query_language', '').lower(): + query = """ +PREFIX rdf: +PREFIX rdfs: + +SELECT ?subject ?predicate ?object WHERE { + ?subject ?predicate ?object . +} +LIMIT 10 +""" + from .sparql_generator import SPARQLQuery + return SPARQLQuery( + query=query, + variables=['subject', 'predicate', 'object'], + query_type='SELECT', + explanation='Fallback SPARQL query', + complexity_score=0.1 + ) + else: + query = "MATCH (n) RETURN n LIMIT 10" + from .cypher_generator import CypherQuery + return CypherQuery( + query=query, + variables=['n'], + query_type='MATCH', + explanation='Fallback Cypher query', + complexity_score=0.1 + ) + + async def _fallback_query_execution(self, context: ErrorContext, **kwargs): + """Fallback for query execution.""" + # Return empty results + if 'sparql' in context.metadata.get('query_language', '').lower(): + from .sparql_cassandra import SPARQLResult + return SPARQLResult( + bindings=[], + variables=[], + execution_time=0.0 + ) + else: + from .cypher_executor import CypherResult + return CypherResult( + records=[], + summary={'type': 'fallback'}, + metadata={'query': 'fallback'}, + execution_time=0.0 + ) + + async def _fallback_answer_generation(self, context: ErrorContext, question: str = None, **kwargs): + """Fallback for answer generation.""" + fallback_messages = [ + "I'm experiencing some technical difficulties. Please try rephrasing your question.", + "I couldn't process your question at the moment. Could you try asking it differently?", + "There seems to be an issue with my analysis. Please try again in a moment.", + "I'm having trouble understanding your question right now. Please try again." 
+ ] + + import random + return random.choice(fallback_messages) + + async def _fallback_backend_connection(self, context: ErrorContext, **kwargs): + """Fallback for backend connection.""" + logger.warning(f"Backend connection failed for {context.component}") + # Could switch to alternative backend here + return None + + +def with_error_handling(category: ErrorCategory, + component: str, + operation: str, + severity: ErrorSeverity = ErrorSeverity.MEDIUM): + """Decorator for automatic error handling. + + Args: + category: Error category + component: Component name + operation: Operation name + severity: Error severity + """ + def decorator(func): + @wraps(func) + async def async_wrapper(*args, **kwargs): + try: + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) + except Exception as e: + context = ErrorContext( + category=category, + severity=severity, + component=component, + operation=operation, + technical_details=str(e), + metadata={'args': str(args), 'kwargs': str(kwargs)} + ) + + # Get error recovery strategy from first argument if it's available + error_strategy = None + if args and hasattr(args[0], '_error_strategy'): + error_strategy = args[0]._error_strategy + + if error_strategy: + return await error_strategy.handle_error(e, context, func, *args, **kwargs) + else: + # Re-raise as OntoRAG error + raise OntoRAGError( + f"Error in {component}.{operation}: {str(e)}", + context=context, + cause=e + ) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + context = ErrorContext( + category=category, + severity=severity, + component=component, + operation=operation, + technical_details=str(e), + metadata={'args': str(args), 'kwargs': str(kwargs)} + ) + + raise OntoRAGError( + f"Error in {component}.{operation}: {str(e)}", + context=context, + cause=e + ) + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator + + +class ErrorReporter: + """Reports and tracks errors for monitoring and debugging.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize error reporter. + + Args: + config: Reporter configuration + """ + self.config = config or {} + self.error_log: List[Dict[str, Any]] = [] + self.max_log_size = self.config.get('max_log_size', 1000) + + def report_error(self, error: OntoRAGError): + """Report an error for tracking. 
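+
+        Example (sketch; `answer` stands in for any OntoRAG call):
+
+            try:
+                result = await answer(question)
+            except OntoRAGError as exc:
+                reporter.report_error(exc)
+                raise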
+ + Args: + error: The error to report + """ + error_entry = { + 'timestamp': error.timestamp, + 'message': error.message, + 'category': error.context.category.value, + 'severity': error.context.severity.value, + 'component': error.context.component, + 'operation': error.context.operation, + 'retry_count': error.context.retry_count, + 'technical_details': error.context.technical_details, + 'stack_trace': traceback.format_exc() if error.cause else None + } + + self.error_log.append(error_entry) + + # Trim log if too large + if len(self.error_log) > self.max_log_size: + self.error_log = self.error_log[-self.max_log_size:] + + # Log based on severity + if error.context.severity == ErrorSeverity.CRITICAL: + logger.critical(f"CRITICAL ERROR: {error.message}") + elif error.context.severity == ErrorSeverity.HIGH: + logger.error(f"HIGH SEVERITY: {error.message}") + elif error.context.severity == ErrorSeverity.MEDIUM: + logger.warning(f"MEDIUM SEVERITY: {error.message}") + else: + logger.info(f"LOW SEVERITY: {error.message}") + + def get_error_summary(self) -> Dict[str, Any]: + """Get summary of recent errors. + + Returns: + Error summary statistics + """ + if not self.error_log: + return {'total_errors': 0} + + recent_errors = [ + e for e in self.error_log + if time.time() - e['timestamp'] < 3600 # Last hour + ] + + category_counts = {} + severity_counts = {} + component_counts = {} + + for error in recent_errors: + category_counts[error['category']] = category_counts.get(error['category'], 0) + 1 + severity_counts[error['severity']] = severity_counts.get(error['severity'], 0) + 1 + component_counts[error['component']] = component_counts.get(error['component'], 0) + 1 + + return { + 'total_errors': len(self.error_log), + 'recent_errors': len(recent_errors), + 'category_breakdown': category_counts, + 'severity_breakdown': severity_counts, + 'component_breakdown': component_counts, + 'most_recent_error': self.error_log[-1] if self.error_log else None + } \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/monitoring.py b/trustgraph-flow/trustgraph/query/ontology/monitoring.py new file mode 100644 index 00000000..3eac4175 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/monitoring.py @@ -0,0 +1,737 @@ +""" +Performance monitoring and metrics collection for OntoRAG. +Provides comprehensive monitoring of system performance, query patterns, and resource usage. 
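+
+Typical usage (sketch; component and operation names are placeholders):
+
+    monitor = PerformanceMonitor({"enabled": True})
+    timer = monitor.metrics_collector.start_timer(
+        "request_duration", {"component": "query-service"})
+    ...  # handle the request
+    duration = monitor.metrics_collector.stop_timer(timer)
+    monitor.record_request("query-service", "ask", duration, success=True)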
+""" + +import logging +import time +import asyncio +import threading +from typing import Dict, Any, List, Optional, Callable +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from collections import defaultdict, deque +import statistics +from enum import Enum + +logger = logging.getLogger(__name__) + + +class MetricType(Enum): + """Types of metrics to collect.""" + COUNTER = "counter" + GAUGE = "gauge" + HISTOGRAM = "histogram" + TIMER = "timer" + + +@dataclass +class Metric: + """Individual metric data point.""" + name: str + value: float + timestamp: datetime + labels: Dict[str, str] = field(default_factory=dict) + metric_type: MetricType = MetricType.GAUGE + + +@dataclass +class TimerMetric: + """Timer metric for measuring duration.""" + name: str + start_time: float + labels: Dict[str, str] = field(default_factory=dict) + + def stop(self) -> float: + """Stop timer and return duration.""" + return time.time() - self.start_time + + +@dataclass +class PerformanceStats: + """Performance statistics for a component.""" + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + avg_response_time: float = 0.0 + min_response_time: float = float('inf') + max_response_time: float = 0.0 + p95_response_time: float = 0.0 + p99_response_time: float = 0.0 + throughput_per_second: float = 0.0 + error_rate: float = 0.0 + + +@dataclass +class SystemHealth: + """Overall system health metrics.""" + status: str = "healthy" # healthy, degraded, unhealthy + uptime_seconds: float = 0.0 + cpu_usage_percent: float = 0.0 + memory_usage_percent: float = 0.0 + active_connections: int = 0 + queue_size: int = 0 + cache_hit_rate: float = 0.0 + error_rate: float = 0.0 + + +class MetricsCollector: + """Collects and stores metrics data.""" + + def __init__(self, max_metrics: int = 10000, retention_hours: int = 24): + """Initialize metrics collector. + + Args: + max_metrics: Maximum number of metrics to retain + retention_hours: Hours to retain metrics + """ + self.max_metrics = max_metrics + self.retention_hours = retention_hours + self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_metrics)) + self.counters: Dict[str, float] = defaultdict(float) + self.gauges: Dict[str, float] = defaultdict(float) + self.timers: Dict[str, List[float]] = defaultdict(list) + self._lock = threading.RLock() + + def increment(self, name: str, value: float = 1.0, labels: Dict[str, str] = None): + """Increment a counter metric. + + Args: + name: Metric name + value: Value to increment by + labels: Metric labels + """ + with self._lock: + metric_key = self._build_key(name, labels) + self.counters[metric_key] += value + self._add_metric(name, value, MetricType.COUNTER, labels) + + def set_gauge(self, name: str, value: float, labels: Dict[str, str] = None): + """Set a gauge metric value. + + Args: + name: Metric name + value: Gauge value + labels: Metric labels + """ + with self._lock: + metric_key = self._build_key(name, labels) + self.gauges[metric_key] = value + self._add_metric(name, value, MetricType.GAUGE, labels) + + def record_timer(self, name: str, duration: float, labels: Dict[str, str] = None): + """Record a timer measurement. 
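+
+        Only the most recent 1000 measurements per metric key are kept for
+        the percentile calculations; older values are discarded.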
+ + Args: + name: Metric name + duration: Duration in seconds + labels: Metric labels + """ + with self._lock: + metric_key = self._build_key(name, labels) + self.timers[metric_key].append(duration) + + # Keep only recent measurements + max_timer_values = 1000 + if len(self.timers[metric_key]) > max_timer_values: + self.timers[metric_key] = self.timers[metric_key][-max_timer_values:] + + self._add_metric(name, duration, MetricType.TIMER, labels) + + def start_timer(self, name: str, labels: Dict[str, str] = None) -> TimerMetric: + """Start a timer. + + Args: + name: Metric name + labels: Metric labels + + Returns: + Timer metric object + """ + return TimerMetric(name=name, start_time=time.time(), labels=labels or {}) + + def stop_timer(self, timer: TimerMetric): + """Stop a timer and record the measurement. + + Args: + timer: Timer metric to stop + """ + duration = timer.stop() + self.record_timer(timer.name, duration, timer.labels) + return duration + + def get_counter(self, name: str, labels: Dict[str, str] = None) -> float: + """Get counter value. + + Args: + name: Metric name + labels: Metric labels + + Returns: + Counter value + """ + metric_key = self._build_key(name, labels) + return self.counters.get(metric_key, 0.0) + + def get_gauge(self, name: str, labels: Dict[str, str] = None) -> float: + """Get gauge value. + + Args: + name: Metric name + labels: Metric labels + + Returns: + Gauge value + """ + metric_key = self._build_key(name, labels) + return self.gauges.get(metric_key, 0.0) + + def get_timer_stats(self, name: str, labels: Dict[str, str] = None) -> Dict[str, float]: + """Get timer statistics. + + Args: + name: Metric name + labels: Metric labels + + Returns: + Timer statistics + """ + metric_key = self._build_key(name, labels) + values = self.timers.get(metric_key, []) + + if not values: + return {} + + sorted_values = sorted(values) + return { + 'count': len(values), + 'sum': sum(values), + 'avg': statistics.mean(values), + 'min': min(values), + 'max': max(values), + 'p50': sorted_values[int(len(sorted_values) * 0.5)], + 'p95': sorted_values[int(len(sorted_values) * 0.95)], + 'p99': sorted_values[int(len(sorted_values) * 0.99)] + } + + def get_metrics(self, + name_pattern: Optional[str] = None, + since: Optional[datetime] = None) -> List[Metric]: + """Get metrics matching pattern and time range. 
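+
+        Matching is a simple substring test on the metric name, and when
+        `since` is omitted the collector's retention window is used as the
+        cut-off.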
+ + Args: + name_pattern: Pattern to match metric names + since: Only return metrics since this time + + Returns: + List of matching metrics + """ + with self._lock: + results = [] + cutoff_time = since or datetime.now() - timedelta(hours=self.retention_hours) + + for metric_name, metric_queue in self.metrics.items(): + if name_pattern and name_pattern not in metric_name: + continue + + for metric in metric_queue: + if metric.timestamp >= cutoff_time: + results.append(metric) + + return sorted(results, key=lambda m: m.timestamp) + + def cleanup_old_metrics(self): + """Remove old metrics beyond retention period.""" + with self._lock: + cutoff_time = datetime.now() - timedelta(hours=self.retention_hours) + + for metric_name in list(self.metrics.keys()): + metric_queue = self.metrics[metric_name] + # Remove old metrics + while metric_queue and metric_queue[0].timestamp < cutoff_time: + metric_queue.popleft() + + # Remove empty queues + if not metric_queue: + del self.metrics[metric_name] + + def _add_metric(self, name: str, value: float, metric_type: MetricType, labels: Dict[str, str]): + """Add metric to storage.""" + metric = Metric( + name=name, + value=value, + timestamp=datetime.now(), + labels=labels or {}, + metric_type=metric_type + ) + self.metrics[name].append(metric) + + def _build_key(self, name: str, labels: Dict[str, str]) -> str: + """Build metric key from name and labels.""" + if not labels: + return name + + label_str = ','.join(f"{k}={v}" for k, v in sorted(labels.items())) + return f"{name}{{{label_str}}}" + + +class PerformanceMonitor: + """Monitors system performance and component health.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize performance monitor. + + Args: + config: Monitor configuration + """ + self.config = config or {} + self.metrics_collector = MetricsCollector( + max_metrics=self.config.get('max_metrics', 10000), + retention_hours=self.config.get('retention_hours', 24) + ) + + self.component_stats: Dict[str, PerformanceStats] = {} + self.start_time = time.time() + self.monitoring_enabled = self.config.get('enabled', True) + + # Start background monitoring tasks + if self.monitoring_enabled: + self._start_background_tasks() + + def record_request(self, + component: str, + operation: str, + duration: float, + success: bool = True, + labels: Dict[str, str] = None): + """Record a request completion. + + Args: + component: Component name + operation: Operation name + duration: Request duration in seconds + success: Whether request was successful + labels: Additional labels + """ + if not self.monitoring_enabled: + return + + base_labels = {'component': component, 'operation': operation} + if labels: + base_labels.update(labels) + + # Record metrics + self.metrics_collector.increment('requests_total', labels=base_labels) + self.metrics_collector.record_timer('request_duration', duration, base_labels) + + if success: + self.metrics_collector.increment('requests_successful', labels=base_labels) + else: + self.metrics_collector.increment('requests_failed', labels=base_labels) + + # Update component stats + self._update_component_stats(component, duration, success) + + def record_query_complexity(self, + complexity_score: float, + query_type: str, + backend: str): + """Record query complexity metrics. 
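+
+        The score is stored as a gauge labelled by query_type and backend,
+        e.g. query_complexity{backend=neo4j,query_type=cypher} (label
+        values illustrative).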
+ + Args: + complexity_score: Query complexity score (0.0 to 1.0) + query_type: Type of query (SPARQL, Cypher) + backend: Backend used + """ + if not self.monitoring_enabled: + return + + labels = {'query_type': query_type, 'backend': backend} + self.metrics_collector.set_gauge('query_complexity', complexity_score, labels) + + def record_cache_access(self, hit: bool, cache_type: str = 'default'): + """Record cache access. + + Args: + hit: Whether it was a cache hit + cache_type: Type of cache + """ + if not self.monitoring_enabled: + return + + labels = {'cache_type': cache_type} + self.metrics_collector.increment('cache_requests_total', labels=labels) + + if hit: + self.metrics_collector.increment('cache_hits_total', labels=labels) + else: + self.metrics_collector.increment('cache_misses_total', labels=labels) + + def record_ontology_selection(self, + selected_elements: int, + total_elements: int, + ontology_id: str): + """Record ontology selection metrics. + + Args: + selected_elements: Number of selected ontology elements + total_elements: Total ontology elements + ontology_id: Ontology identifier + """ + if not self.monitoring_enabled: + return + + labels = {'ontology_id': ontology_id} + self.metrics_collector.set_gauge('ontology_elements_selected', selected_elements, labels) + self.metrics_collector.set_gauge('ontology_elements_total', total_elements, labels) + + selection_ratio = selected_elements / total_elements if total_elements > 0 else 0 + self.metrics_collector.set_gauge('ontology_selection_ratio', selection_ratio, labels) + + def get_component_stats(self, component: str) -> Optional[PerformanceStats]: + """Get performance statistics for a component. + + Args: + component: Component name + + Returns: + Performance statistics or None + """ + return self.component_stats.get(component) + + def get_system_health(self) -> SystemHealth: + """Get overall system health status. + + Returns: + System health metrics + """ + # Calculate uptime + uptime = time.time() - self.start_time + + # Get error rate + total_requests = self.metrics_collector.get_counter('requests_total') + failed_requests = self.metrics_collector.get_counter('requests_failed') + error_rate = failed_requests / total_requests if total_requests > 0 else 0.0 + + # Get cache hit rate + cache_hits = self.metrics_collector.get_counter('cache_hits_total') + cache_requests = self.metrics_collector.get_counter('cache_requests_total') + cache_hit_rate = cache_hits / cache_requests if cache_requests > 0 else 0.0 + + # Determine status + status = "healthy" + if error_rate > 0.1: # More than 10% error rate + status = "degraded" + if error_rate > 0.3: # More than 30% error rate + status = "unhealthy" + + return SystemHealth( + status=status, + uptime_seconds=uptime, + error_rate=error_rate, + cache_hit_rate=cache_hit_rate + ) + + def get_performance_report(self) -> Dict[str, Any]: + """Get comprehensive performance report. 
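+
+        The report covers system health, per-component statistics, the ten
+        slowest operations ranked by p95 latency, and per-cache hit rates.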
+ + Returns: + Performance report + """ + report = { + 'system_health': self.get_system_health(), + 'component_stats': {}, + 'top_slow_operations': [], + 'error_patterns': {}, + 'cache_performance': {}, + 'ontology_usage': {} + } + + # Component statistics + for component, stats in self.component_stats.items(): + report['component_stats'][component] = stats + + # Top slow operations + timer_stats = {} + for metric_name in self.metrics_collector.timers.keys(): + if 'request_duration' in metric_name: + stats = self.metrics_collector.get_timer_stats(metric_name) + if stats: + timer_stats[metric_name] = stats + + # Sort by p95 latency + slow_ops = sorted( + timer_stats.items(), + key=lambda x: x[1].get('p95', 0), + reverse=True + )[:10] + + report['top_slow_operations'] = [ + {'operation': op, 'stats': stats} for op, stats in slow_ops + ] + + # Cache performance + cache_types = set() + for metric_name in self.metrics_collector.counters.keys(): + if 'cache_type=' in metric_name: + cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0] + cache_types.add(cache_type) + + for cache_type in cache_types: + labels = {'cache_type': cache_type} + hits = self.metrics_collector.get_counter('cache_hits_total', labels) + requests = self.metrics_collector.get_counter('cache_requests_total', labels) + hit_rate = hits / requests if requests > 0 else 0.0 + + report['cache_performance'][cache_type] = { + 'hit_rate': hit_rate, + 'total_requests': requests, + 'total_hits': hits + } + + return report + + def _update_component_stats(self, component: str, duration: float, success: bool): + """Update component performance statistics.""" + if component not in self.component_stats: + self.component_stats[component] = PerformanceStats() + + stats = self.component_stats[component] + stats.total_requests += 1 + + if success: + stats.successful_requests += 1 + else: + stats.failed_requests += 1 + + # Update response time stats + stats.min_response_time = min(stats.min_response_time, duration) + stats.max_response_time = max(stats.max_response_time, duration) + + # Get timer stats for percentiles + timer_stats = self.metrics_collector.get_timer_stats( + 'request_duration', {'component': component} + ) + + if timer_stats: + stats.avg_response_time = timer_stats.get('avg', 0.0) + stats.p95_response_time = timer_stats.get('p95', 0.0) + stats.p99_response_time = timer_stats.get('p99', 0.0) + + # Calculate rates + stats.error_rate = stats.failed_requests / stats.total_requests + + # Calculate throughput (requests per second over last minute) + recent_requests = len([ + m for m in self.metrics_collector.get_metrics('requests_total') + if m.labels.get('component') == component and + m.timestamp > datetime.now() - timedelta(minutes=1) + ]) + stats.throughput_per_second = recent_requests / 60.0 + + def _start_background_tasks(self): + """Start background monitoring tasks.""" + def cleanup_worker(): + """Worker to clean up old metrics.""" + while True: + try: + time.sleep(300) # 5 minutes + self.metrics_collector.cleanup_old_metrics() + except Exception as e: + logger.error(f"Metrics cleanup error: {e}") + + cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) + cleanup_thread.start() + + +# Monitoring decorators + +def monitor_performance(component: str, + operation: str, + monitor: Optional[PerformanceMonitor] = None): + """Decorator to monitor function performance. 
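+
+    Example (sketch; component and operation names are placeholders):
+
+        monitor = PerformanceMonitor()
+
+        @monitor_performance("cypher-generator", "generate", monitor)
+        async def generate(components, subset):
+            ...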
+ + Args: + component: Component name + operation: Operation name + monitor: Performance monitor instance + """ + def decorator(func): + def wrapper(*args, **kwargs): + if not monitor or not monitor.monitoring_enabled: + return func(*args, **kwargs) + + timer = monitor.metrics_collector.start_timer( + 'request_duration', + {'component': component, 'operation': operation} + ) + + success = True + try: + result = func(*args, **kwargs) + return result + except Exception as e: + success = False + raise + finally: + duration = monitor.metrics_collector.stop_timer(timer) + monitor.record_request(component, operation, duration, success) + + async def async_wrapper(*args, **kwargs): + if not monitor or not monitor.monitoring_enabled: + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) + + timer = monitor.metrics_collector.start_timer( + 'request_duration', + {'component': component, 'operation': operation} + ) + + success = True + try: + if asyncio.iscoroutinefunction(func): + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + return result + except Exception as e: + success = False + raise + finally: + duration = monitor.metrics_collector.stop_timer(timer) + monitor.record_request(component, operation, duration, success) + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return wrapper + + return decorator + + +class QueryPatternAnalyzer: + """Analyzes query patterns for optimization insights.""" + + def __init__(self, monitor: PerformanceMonitor): + """Initialize query pattern analyzer. + + Args: + monitor: Performance monitor instance + """ + self.monitor = monitor + self.query_patterns: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + + def record_query_pattern(self, + question_type: str, + entities: List[str], + complexity: float, + backend: str, + duration: float, + success: bool): + """Record a query pattern for analysis. + + Args: + question_type: Type of question + entities: Entities in question + complexity: Query complexity score + backend: Backend used + duration: Query duration + success: Whether query succeeded + """ + pattern = { + 'timestamp': datetime.now(), + 'question_type': question_type, + 'entity_count': len(entities), + 'entities': entities, + 'complexity': complexity, + 'backend': backend, + 'duration': duration, + 'success': success + } + + pattern_key = f"{question_type}:{len(entities)}" + self.query_patterns[pattern_key].append(pattern) + + # Keep only recent patterns + cutoff_time = datetime.now() - timedelta(hours=24) + self.query_patterns[pattern_key] = [ + p for p in self.query_patterns[pattern_key] + if p['timestamp'] > cutoff_time + ] + + def get_optimization_insights(self) -> Dict[str, Any]: + """Get insights for query optimization. 
+ + Returns: + Optimization insights and recommendations + """ + insights = { + 'slow_patterns': [], + 'common_failures': [], + 'backend_performance': {}, + 'complexity_analysis': {}, + 'recommendations': [] + } + + # Analyze slow patterns + for pattern_key, patterns in self.query_patterns.items(): + if not patterns: + continue + + avg_duration = statistics.mean([p['duration'] for p in patterns]) + success_rate = sum(1 for p in patterns if p['success']) / len(patterns) + + if avg_duration > 5.0: # Slow queries > 5 seconds + insights['slow_patterns'].append({ + 'pattern': pattern_key, + 'avg_duration': avg_duration, + 'count': len(patterns), + 'success_rate': success_rate + }) + + if success_rate < 0.8: # Low success rate + insights['common_failures'].append({ + 'pattern': pattern_key, + 'success_rate': success_rate, + 'count': len(patterns) + }) + + # Analyze backend performance + backend_stats = defaultdict(list) + for patterns in self.query_patterns.values(): + for pattern in patterns: + backend_stats[pattern['backend']].append(pattern['duration']) + + for backend, durations in backend_stats.items(): + insights['backend_performance'][backend] = { + 'avg_duration': statistics.mean(durations), + 'p95_duration': sorted(durations)[int(len(durations) * 0.95)], + 'query_count': len(durations) + } + + # Generate recommendations + recommendations = [] + + # Slow pattern recommendations + for slow_pattern in insights['slow_patterns']: + recommendations.append( + f"Consider optimizing {slow_pattern['pattern']} queries - " + f"average duration {slow_pattern['avg_duration']:.2f}s" + ) + + # Backend recommendations + if len(insights['backend_performance']) > 1: + fastest_backend = min( + insights['backend_performance'].items(), + key=lambda x: x[1]['avg_duration'] + )[0] + recommendations.append( + f"Consider routing more queries to {fastest_backend} " + f"for better performance" + ) + + insights['recommendations'] = recommendations + + return insights \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/multi_language.py b/trustgraph-flow/trustgraph/query/ontology/multi_language.py new file mode 100644 index 00000000..d7b7883a --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/multi_language.py @@ -0,0 +1,656 @@ +""" +Multi-language support for OntoRAG. +Provides language detection, translation, and multilingual query processing. +""" + +import logging +from typing import Dict, Any, List, Optional, Tuple +from dataclasses import dataclass +from enum import Enum + +logger = logging.getLogger(__name__) + + +class Language(Enum): + """Supported languages.""" + ENGLISH = "en" + SPANISH = "es" + FRENCH = "fr" + GERMAN = "de" + ITALIAN = "it" + PORTUGUESE = "pt" + CHINESE = "zh" + JAPANESE = "ja" + KOREAN = "ko" + ARABIC = "ar" + RUSSIAN = "ru" + DUTCH = "nl" + + +@dataclass +class LanguageDetectionResult: + """Language detection result.""" + language: Language + confidence: float + detected_text: str + alternative_languages: List[Tuple[Language, float]] = None + + +@dataclass +class TranslationResult: + """Translation result.""" + original_text: str + translated_text: str + source_language: Language + target_language: Language + confidence: float + + +class LanguageDetector: + """Detects language of input text.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize language detector. 
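+
+        Example configuration (illustrative values; keys as read below)::
+
+            {"default_language": "en", "confidence_threshold": 0.8}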
+ + Args: + config: Detector configuration + """ + self.config = config or {} + self.default_language = Language(self.config.get('default_language', 'en')) + self.confidence_threshold = self.config.get('confidence_threshold', 0.7) + + # Try to import language detection libraries + self.detector = None + self._init_detector() + + def _init_detector(self): + """Initialize language detection backend.""" + try: + # Try langdetect first + import langdetect + self.detector = 'langdetect' + logger.info("Using langdetect for language detection") + except ImportError: + try: + # Try textblob as fallback + from textblob import TextBlob + self.detector = 'textblob' + logger.info("Using TextBlob for language detection") + except ImportError: + logger.warning("No language detection library available, using rule-based detection") + self.detector = 'rule_based' + + def detect_language(self, text: str) -> LanguageDetectionResult: + """Detect language of input text. + + Args: + text: Text to analyze + + Returns: + Language detection result + """ + if not text or not text.strip(): + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + try: + if self.detector == 'langdetect': + return self._detect_with_langdetect(text) + elif self.detector == 'textblob': + return self._detect_with_textblob(text) + else: + return self._detect_with_rules(text) + + except Exception as e: + logger.error(f"Language detection failed: {e}") + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + def _detect_with_langdetect(self, text: str) -> LanguageDetectionResult: + """Detect language using langdetect library.""" + import langdetect + from langdetect.lang_detect_exception import LangDetectException + + try: + # Get detailed detection results + probabilities = langdetect.detect_langs(text) + + if not probabilities: + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + best_match = probabilities[0] + detected_lang_code = best_match.lang + confidence = best_match.prob + + # Map to our Language enum + try: + detected_language = Language(detected_lang_code) + except ValueError: + # Map common variations + lang_mapping = { + 'ca': Language.SPANISH, # Catalan -> Spanish + 'eu': Language.SPANISH, # Basque -> Spanish + 'gl': Language.SPANISH, # Galician -> Spanish + 'zh-cn': Language.CHINESE, + 'zh-tw': Language.CHINESE, + } + detected_language = lang_mapping.get(detected_lang_code, self.default_language) + + # Get alternatives + alternatives = [] + for lang_prob in probabilities[1:3]: # Top 3 alternatives + try: + alt_lang = Language(lang_prob.lang) + alternatives.append((alt_lang, lang_prob.prob)) + except ValueError: + continue + + return LanguageDetectionResult( + language=detected_language, + confidence=confidence, + detected_text=text, + alternative_languages=alternatives + ) + + except LangDetectException: + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + def _detect_with_textblob(self, text: str) -> LanguageDetectionResult: + """Detect language using TextBlob.""" + from textblob import TextBlob + + try: + blob = TextBlob(text) + detected_lang_code = blob.detect_language() + + try: + detected_language = Language(detected_lang_code) + except ValueError: + detected_language = self.default_language + + # TextBlob doesn't provide confidence, so estimate based on text length + confidence = 
min(0.8, len(text) / 100.0) if len(text) > 10 else 0.5 + + return LanguageDetectionResult( + language=detected_language, + confidence=confidence, + detected_text=text + ) + + except Exception: + return LanguageDetectionResult( + language=self.default_language, + confidence=0.0, + detected_text=text + ) + + def _detect_with_rules(self, text: str) -> LanguageDetectionResult: + """Rule-based language detection fallback.""" + text_lower = text.lower() + + # Simple keyword-based detection + language_keywords = { + Language.SPANISH: ['qué', 'cuál', 'cuándo', 'dónde', 'cómo', 'por qué', 'cuántos'], + Language.FRENCH: ['que', 'quel', 'quand', 'où', 'comment', 'pourquoi', 'combien'], + Language.GERMAN: ['was', 'welche', 'wann', 'wo', 'wie', 'warum', 'wieviele'], + Language.ITALIAN: ['che', 'quale', 'quando', 'dove', 'come', 'perché', 'quanti'], + Language.PORTUGUESE: ['que', 'qual', 'quando', 'onde', 'como', 'por que', 'quantos'], + Language.DUTCH: ['wat', 'welke', 'wanneer', 'waar', 'hoe', 'waarom', 'hoeveel'] + } + + best_match = self.default_language + best_score = 0 + + for language, keywords in language_keywords.items(): + score = sum(1 for keyword in keywords if keyword in text_lower) + if score > best_score: + best_score = score + best_match = language + + confidence = min(0.8, best_score / 3.0) if best_score > 0 else 0.1 + + return LanguageDetectionResult( + language=best_match, + confidence=confidence, + detected_text=text + ) + + +class TextTranslator: + """Translates text between languages.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize text translator. + + Args: + config: Translator configuration + """ + self.config = config or {} + self.translator = None + self._init_translator() + + def _init_translator(self): + """Initialize translation backend.""" + try: + # Try Google Translate first + from googletrans import Translator + self.translator = Translator() + self.backend = 'googletrans' + logger.info("Using Google Translate for translation") + except ImportError: + try: + # Try TextBlob as fallback + from textblob import TextBlob + self.backend = 'textblob' + logger.info("Using TextBlob for translation") + except ImportError: + logger.warning("No translation library available") + self.backend = None + + def translate(self, + text: str, + target_language: Language, + source_language: Optional[Language] = None) -> TranslationResult: + """Translate text to target language. 
+ + Args: + text: Text to translate + target_language: Target language + source_language: Source language (auto-detect if None) + + Returns: + Translation result + """ + if not text or not text.strip(): + return TranslationResult( + original_text=text, + translated_text=text, + source_language=source_language or Language.ENGLISH, + target_language=target_language, + confidence=0.0 + ) + + try: + if self.backend == 'googletrans': + return self._translate_with_googletrans(text, target_language, source_language) + elif self.backend == 'textblob': + return self._translate_with_textblob(text, target_language, source_language) + else: + # No translation available + return TranslationResult( + original_text=text, + translated_text=text, + source_language=source_language or Language.ENGLISH, + target_language=target_language, + confidence=0.0 + ) + + except Exception as e: + logger.error(f"Translation failed: {e}") + return TranslationResult( + original_text=text, + translated_text=text, + source_language=source_language or Language.ENGLISH, + target_language=target_language, + confidence=0.0 + ) + + def _translate_with_googletrans(self, + text: str, + target_language: Language, + source_language: Optional[Language]) -> TranslationResult: + """Translate using Google Translate.""" + try: + src_code = source_language.value if source_language else 'auto' + dest_code = target_language.value + + result = self.translator.translate(text, src=src_code, dest=dest_code) + + detected_source = Language(result.src) if result.src != 'auto' else Language.ENGLISH + confidence = 0.9 # Google Translate is generally reliable + + return TranslationResult( + original_text=text, + translated_text=result.text, + source_language=detected_source, + target_language=target_language, + confidence=confidence + ) + + except Exception as e: + logger.error(f"Google Translate error: {e}") + raise + + def _translate_with_textblob(self, + text: str, + target_language: Language, + source_language: Optional[Language]) -> TranslationResult: + """Translate using TextBlob.""" + from textblob import TextBlob + + try: + blob = TextBlob(text) + + if not source_language: + # Auto-detect source language + detected_lang = blob.detect_language() + try: + source_language = Language(detected_lang) + except ValueError: + source_language = Language.ENGLISH + + translated_blob = blob.translate(to=target_language.value) + translated_text = str(translated_blob) + + # TextBlob confidence estimation + confidence = 0.7 if len(text) > 10 else 0.5 + + return TranslationResult( + original_text=text, + translated_text=translated_text, + source_language=source_language, + target_language=target_language, + confidence=confidence + ) + + except Exception as e: + logger.error(f"TextBlob translation error: {e}") + raise + + +class MultiLanguageQueryProcessor: + """Processes queries in multiple languages.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize multi-language query processor. + + Args: + config: Processor configuration + """ + self.config = config or {} + self.language_detector = LanguageDetector(config.get('language_detection', {})) + self.translator = TextTranslator(config.get('translation', {})) + self.supported_languages = [Language(lang) for lang in config.get('supported_languages', ['en'])] + self.primary_language = Language(config.get('primary_language', 'en')) + + async def process_multilingual_query(self, question: str) -> Dict[str, Any]: + """Process a query in any supported language. 
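+
+        Illustrative call (a sketch; the configuration values are assumptions)::
+
+            processor = MultiLanguageQueryProcessor({
+                "primary_language": "en",
+                "supported_languages": ["en", "es", "fr"],
+            })
+            info = await processor.process_multilingual_query("¿Cuántos animales hay?")
+            # info["translated_question"] holds the English text used downstream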
+ + Args: + question: Question in any language + + Returns: + Processing result with language information + """ + # Step 1: Detect language + detection_result = self.language_detector.detect_language(question) + detected_language = detection_result.language + + logger.info(f"Detected language: {detected_language.value} " + f"(confidence: {detection_result.confidence:.2f})") + + # Step 2: Translate to primary language if needed + translated_question = question + translation_result = None + + if detected_language != self.primary_language: + if detection_result.confidence >= self.language_detector.confidence_threshold: + translation_result = self.translator.translate( + question, self.primary_language, detected_language + ) + translated_question = translation_result.translated_text + logger.info(f"Translated question: {translated_question}") + else: + logger.warning(f"Low confidence language detection, processing in {self.primary_language.value}") + + # Step 3: Return processing information + return { + 'original_question': question, + 'translated_question': translated_question, + 'detected_language': detected_language, + 'detection_confidence': detection_result.confidence, + 'translation_result': translation_result, + 'processing_language': self.primary_language, + 'alternative_languages': detection_result.alternative_languages + } + + async def translate_answer(self, + answer: str, + target_language: Language) -> TranslationResult: + """Translate answer back to target language. + + Args: + answer: Answer in primary language + target_language: Target language for answer + + Returns: + Translation result + """ + if target_language == self.primary_language: + # No translation needed + return TranslationResult( + original_text=answer, + translated_text=answer, + source_language=self.primary_language, + target_language=target_language, + confidence=1.0 + ) + + return self.translator.translate(answer, target_language, self.primary_language) + + def get_language_specific_ontology_terms(self, + ontology_subset: Dict[str, Any], + language: Language) -> Dict[str, Any]: + """Get language-specific terms from ontology. + + Args: + ontology_subset: Ontology subset + language: Target language + + Returns: + Language-specific ontology terms + """ + # Extract language-specific labels and descriptions + lang_code = language.value + result = {} + + # Process classes + if 'classes' in ontology_subset: + result['classes'] = {} + for class_id, class_def in ontology_subset['classes'].items(): + lang_labels = [] + if 'labels' in class_def: + for label in class_def['labels']: + if isinstance(label, dict) and label.get('language') == lang_code: + lang_labels.append(label['value']) + elif isinstance(label, str): + lang_labels.append(label) + + result['classes'][class_id] = { + **class_def, + 'language_labels': lang_labels + } + + # Process properties + for prop_type in ['object_properties', 'datatype_properties']: + if prop_type in ontology_subset: + result[prop_type] = {} + for prop_id, prop_def in ontology_subset[prop_type].items(): + lang_labels = [] + if 'labels' in prop_def: + for label in prop_def['labels']: + if isinstance(label, dict) and label.get('language') == lang_code: + lang_labels.append(label['value']) + elif isinstance(label, str): + lang_labels.append(label) + + result[prop_type][prop_id] = { + **prop_def, + 'language_labels': lang_labels + } + + return result + + def is_language_supported(self, language: Language) -> bool: + """Check if language is supported. 
+ + Args: + language: Language to check + + Returns: + True if language is supported + """ + return language in self.supported_languages + + def get_supported_languages(self) -> List[Language]: + """Get list of supported languages. + + Returns: + List of supported languages + """ + return self.supported_languages.copy() + + def add_language_support(self, language: Language): + """Add support for a new language. + + Args: + language: Language to add support for + """ + if language not in self.supported_languages: + self.supported_languages.append(language) + logger.info(f"Added support for language: {language.value}") + + def remove_language_support(self, language: Language): + """Remove support for a language. + + Args: + language: Language to remove support for + """ + if language in self.supported_languages and language != self.primary_language: + self.supported_languages.remove(language) + logger.info(f"Removed support for language: {language.value}") + else: + logger.warning(f"Cannot remove primary language or unsupported language: {language.value}") + + +class LanguageSpecificTemplates: + """Manages language-specific query and answer templates.""" + + def __init__(self): + """Initialize language-specific templates.""" + self.question_templates = { + Language.ENGLISH: { + 'count': ['how many', 'count of', 'number of'], + 'boolean': ['is', 'are', 'does', 'can', 'will'], + 'retrieval': ['what', 'which', 'who', 'where'], + 'factual': ['tell me about', 'describe', 'explain'] + }, + Language.SPANISH: { + 'count': ['cuántos', 'cuántas', 'número de', 'cantidad de'], + 'boolean': ['es', 'son', 'está', 'están', 'puede', 'pueden'], + 'retrieval': ['qué', 'cuál', 'cuáles', 'quién', 'dónde'], + 'factual': ['dime sobre', 'describe', 'explica'] + }, + Language.FRENCH: { + 'count': ['combien', 'nombre de', 'quantité de'], + 'boolean': ['est', 'sont', 'peut', 'peuvent'], + 'retrieval': ['que', 'quel', 'quelle', 'qui', 'où'], + 'factual': ['dis-moi sur', 'décris', 'explique'] + }, + Language.GERMAN: { + 'count': ['wie viele', 'anzahl der', 'zahl der'], + 'boolean': ['ist', 'sind', 'kann', 'können'], + 'retrieval': ['was', 'welche', 'wer', 'wo'], + 'factual': ['erzähl mir über', 'beschreibe', 'erkläre'] + } + } + + self.answer_templates = { + Language.ENGLISH: { + 'count': 'There are {count} {entity}.', + 'boolean_true': 'Yes, {statement}.', + 'boolean_false': 'No, {statement}.', + 'not_found': 'No information found.', + 'error': 'Sorry, I encountered an error.' + }, + Language.SPANISH: { + 'count': 'Hay {count} {entity}.', + 'boolean_true': 'Sí, {statement}.', + 'boolean_false': 'No, {statement}.', + 'not_found': 'No se encontró información.', + 'error': 'Lo siento, encontré un error.' + }, + Language.FRENCH: { + 'count': 'Il y a {count} {entity}.', + 'boolean_true': 'Oui, {statement}.', + 'boolean_false': 'Non, {statement}.', + 'not_found': 'Aucune information trouvée.', + 'error': 'Désolé, j\'ai rencontré une erreur.' + }, + Language.GERMAN: { + 'count': 'Es gibt {count} {entity}.', + 'boolean_true': 'Ja, {statement}.', + 'boolean_false': 'Nein, {statement}.', + 'not_found': 'Keine Informationen gefunden.', + 'error': 'Entschuldigung, ich bin auf einen Fehler gestoßen.' + } + } + + def get_question_patterns(self, language: Language) -> Dict[str, List[str]]: + """Get question patterns for a language. 
+ + Args: + language: Target language + + Returns: + Dictionary of question patterns + """ + return self.question_templates.get(language, self.question_templates[Language.ENGLISH]) + + def get_answer_template(self, language: Language, template_type: str) -> str: + """Get answer template for a language and type. + + Args: + language: Target language + template_type: Template type + + Returns: + Answer template string + """ + templates = self.answer_templates.get(language, self.answer_templates[Language.ENGLISH]) + return templates.get(template_type, templates.get('error', 'Error')) + + def format_answer(self, + language: Language, + template_type: str, + **kwargs) -> str: + """Format answer using language-specific template. + + Args: + language: Target language + template_type: Template type + **kwargs: Template variables + + Returns: + Formatted answer + """ + template = self.get_answer_template(language, template_type) + try: + return template.format(**kwargs) + except KeyError as e: + logger.error(f"Missing template variable: {e}") + return self.get_answer_template(language, 'error') \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py new file mode 100644 index 00000000..895856f3 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py @@ -0,0 +1,256 @@ +""" +Ontology matcher for query system. +Identifies relevant ontology subsets for answering questions. +""" + +import logging +from typing import List, Dict, Any, Set, Optional +from dataclasses import dataclass + +from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader +from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder +from ...extract.kg.ontology.text_processor import TextSegment +from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset +from .question_analyzer import QuestionComponents, QuestionType + +logger = logging.getLogger(__name__) + + +@dataclass +class QueryOntologySubset(OntologySubset): + """Extended ontology subset for query processing.""" + traversal_properties: Dict[str, Any] = None # Additional properties for graph traversal + inference_rules: List[Dict[str, Any]] = None # Inference rules for reasoning + + +class OntologyMatcherForQueries(OntologySelector): + """ + Specialized ontology matcher for question answering. + Extends OntologySelector with query-specific logic. + """ + + def __init__(self, ontology_embedder: OntologyEmbedder, + ontology_loader: OntologyLoader, + top_k: int = 15, # Higher k for queries + similarity_threshold: float = 0.6): # Lower threshold for broader coverage + """Initialize query-specific ontology matcher. + + Args: + ontology_embedder: Embedder with vector store + ontology_loader: Loader with ontology definitions + top_k: Number of top results to retrieve + similarity_threshold: Minimum similarity score + """ + super().__init__(ontology_embedder, ontology_loader, top_k, similarity_threshold) + + async def match_question_to_ontology(self, + question_components: QuestionComponents, + question_segments: List[str]) -> List[QueryOntologySubset]: + """Match question components to relevant ontology elements. 
+ + Args: + question_components: Analyzed question components + question_segments: Text segments from question + + Returns: + List of query-optimized ontology subsets + """ + # Convert question segments to TextSegment objects + text_segments = [ + TextSegment(text=seg, type='question', position=i) + for i, seg in enumerate(question_segments) + ] + + # Get base ontology subsets using parent class method + base_subsets = await self.select_ontology_subset(text_segments) + + # Enhance subsets for query processing + query_subsets = [] + for subset in base_subsets: + query_subset = self._enhance_for_query(subset, question_components) + query_subsets.append(query_subset) + + return query_subsets + + def _enhance_for_query(self, subset: OntologySubset, + question_components: QuestionComponents) -> QueryOntologySubset: + """Enhance ontology subset with query-specific elements. + + Args: + subset: Base ontology subset + question_components: Analyzed question components + + Returns: + Enhanced query ontology subset + """ + # Create query subset + query_subset = QueryOntologySubset( + ontology_id=subset.ontology_id, + classes=dict(subset.classes), + object_properties=dict(subset.object_properties), + datatype_properties=dict(subset.datatype_properties), + metadata=subset.metadata, + relevance_score=subset.relevance_score, + traversal_properties={}, + inference_rules=[] + ) + + # Add traversal properties based on question type + self._add_traversal_properties(query_subset, question_components) + + # Add related properties for exploration + self._add_related_properties(query_subset) + + # Add inference rules if needed + self._add_inference_rules(query_subset, question_components) + + return query_subset + + def _add_traversal_properties(self, subset: QueryOntologySubset, + question_components: QuestionComponents): + """Add properties useful for graph traversal. 
+ + Args: + subset: Query ontology subset to enhance + question_components: Question analysis + """ + ontology = self.loader.get_ontology(subset.ontology_id) + if not ontology: + return + + # For relationship questions, add all properties connecting mentioned classes + if question_components.question_type == QuestionType.RELATIONSHIP: + for prop_id, prop_def in ontology.object_properties.items(): + domain = prop_def.domain + range_val = prop_def.range + + # Check if property connects relevant classes + if domain in subset.classes or range_val in subset.classes: + if prop_id not in subset.object_properties: + subset.traversal_properties[prop_id] = prop_def.__dict__ + logger.debug(f"Added traversal property: {prop_id}") + + # For retrieval questions, add properties that might filter results + elif question_components.question_type == QuestionType.RETRIEVAL: + # Add all properties with domains in our classes + for prop_id, prop_def in ontology.object_properties.items(): + if prop_def.domain in subset.classes: + if prop_id not in subset.object_properties: + subset.traversal_properties[prop_id] = prop_def.__dict__ + + for prop_id, prop_def in ontology.datatype_properties.items(): + if prop_def.domain in subset.classes: + if prop_id not in subset.datatype_properties: + subset.traversal_properties[prop_id] = prop_def.__dict__ + + # For aggregation questions, ensure we have counting properties + elif question_components.question_type == QuestionType.AGGREGATION: + # Add properties that might be counted + for prop_id, prop_def in ontology.datatype_properties.items(): + if 'count' in prop_id.lower() or 'number' in prop_id.lower(): + if prop_id not in subset.datatype_properties: + subset.traversal_properties[prop_id] = prop_def.__dict__ + + def _add_related_properties(self, subset: QueryOntologySubset): + """Add properties related to already selected ones. + + Args: + subset: Query ontology subset to enhance + """ + ontology = self.loader.get_ontology(subset.ontology_id) + if not ontology: + return + + # Add inverse properties + for prop_id in list(subset.object_properties.keys()): + prop = ontology.object_properties.get(prop_id) + if prop and prop.inverse_of: + inverse_prop = ontology.object_properties.get(prop.inverse_of) + if inverse_prop and prop.inverse_of not in subset.object_properties: + subset.object_properties[prop.inverse_of] = inverse_prop.__dict__ + logger.debug(f"Added inverse property: {prop.inverse_of}") + + # Add sibling properties (same domain) + domains_in_subset = set() + for prop_def in subset.object_properties.values(): + if 'domain' in prop_def and prop_def['domain']: + domains_in_subset.add(prop_def['domain']) + + for domain in domains_in_subset: + for prop_id, prop_def in ontology.object_properties.items(): + if prop_def.domain == domain and prop_id not in subset.object_properties: + # Add up to 3 sibling properties + if len(subset.traversal_properties) < 3: + subset.traversal_properties[prop_id] = prop_def.__dict__ + + def _add_inference_rules(self, subset: QueryOntologySubset, + question_components: QuestionComponents): + """Add inference rules for reasoning. 
+ + Args: + subset: Query ontology subset to enhance + question_components: Question analysis + """ + # Add transitivity rules for subclass relationships + if any(cls.get('subclass_of') for cls in subset.classes.values()): + subset.inference_rules.append({ + 'type': 'transitivity', + 'property': 'rdfs:subClassOf', + 'description': 'Subclass relationships are transitive' + }) + + # Add symmetry rules for equivalent classes + if any(cls.get('equivalent_classes') for cls in subset.classes.values()): + subset.inference_rules.append({ + 'type': 'symmetry', + 'property': 'owl:equivalentClass', + 'description': 'Equivalent class relationships are symmetric' + }) + + # Add inverse property rules + for prop_id, prop_def in subset.object_properties.items(): + if 'inverse_of' in prop_def and prop_def['inverse_of']: + subset.inference_rules.append({ + 'type': 'inverse', + 'property': prop_id, + 'inverse': prop_def['inverse_of'], + 'description': f'{prop_id} is inverse of {prop_def["inverse_of"]}' + }) + + def expand_for_hierarchical_queries(self, subset: QueryOntologySubset) -> QueryOntologySubset: + """Expand subset to include full class hierarchies. + + Args: + subset: Query ontology subset + + Returns: + Expanded subset with complete hierarchies + """ + ontology = self.loader.get_ontology(subset.ontology_id) + if not ontology: + return subset + + # Add all parent and child classes + classes_to_add = set() + for class_id in list(subset.classes.keys()): + # Add all parents + parents = ontology.get_parent_classes(class_id) + for parent_id in parents: + if parent_id not in subset.classes: + parent_class = ontology.get_class(parent_id) + if parent_class: + classes_to_add.add(parent_id) + + # Add all children + for other_class_id, other_class in ontology.classes.items(): + if other_class.subclass_of == class_id and other_class_id not in subset.classes: + classes_to_add.add(other_class_id) + + # Add collected classes + for class_id in classes_to_add: + cls = ontology.get_class(class_id) + if cls: + subset.classes[class_id] = cls.__dict__ + + logger.debug(f"Expanded hierarchy: added {len(classes_to_add)} classes") + return subset \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/query_explanation.py b/trustgraph-flow/trustgraph/query/ontology/query_explanation.py new file mode 100644 index 00000000..bd72aedc --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/query_explanation.py @@ -0,0 +1,640 @@ +""" +Query explanation system for OntoRAG. +Provides detailed explanations of how queries are processed and results are derived. 
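+
+Illustrative usage (a sketch; assumes the question-analysis, ontology-matching and
+query-execution artefacts are already available from the rest of the pipeline):
+
+    explainer = QueryExplainer({'explanation_level': 'detailed'})
+    explanation = explainer.explain_query_processing(
+        question, components, subsets, query, results, answer, metadata)
+    print(explainer.format_explanation_for_display(explanation, format_type='markdown'))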
+""" + +import logging +from typing import Dict, Any, List, Optional, Union +from dataclasses import dataclass, field +from datetime import datetime + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset +from .sparql_generator import SPARQLQuery +from .cypher_generator import CypherQuery +from .sparql_cassandra import SPARQLResult +from .cypher_executor import CypherResult + +logger = logging.getLogger(__name__) + + +@dataclass +class ExplanationStep: + """Individual step in query explanation.""" + step_number: int + component: str + operation: str + input_data: Dict[str, Any] + output_data: Dict[str, Any] + explanation: str + duration_ms: float + success: bool + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class QueryExplanation: + """Complete explanation of query processing.""" + query_id: str + original_question: str + processing_steps: List[ExplanationStep] + final_answer: str + confidence_score: float + total_duration_ms: float + ontologies_used: List[str] + backend_used: str + reasoning_chain: List[str] + technical_details: Dict[str, Any] + user_friendly_explanation: str + + +class QueryExplainer: + """Generates explanations for query processing.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize query explainer. + + Args: + config: Explainer configuration + """ + self.config = config or {} + self.explanation_level = self.config.get('explanation_level', 'detailed') # basic, detailed, technical + self.include_technical_details = self.config.get('include_technical_details', True) + self.max_reasoning_steps = self.config.get('max_reasoning_steps', 10) + + # Templates for different explanation types + self.step_templates = { + 'question_analysis': { + 'basic': "I analyzed your question to understand what you're asking.", + 'detailed': "I analyzed your question '{question}' and identified it as a {question_type} query about {entities}.", + 'technical': "Question analysis: Type={question_type}, Entities={entities}, Keywords={keywords}, Expected answer={answer_type}" + }, + 'ontology_matching': { + 'basic': "I found relevant knowledge about {entities} in the available ontologies.", + 'detailed': "I searched through {ontology_count} ontologies and found {selected_elements} relevant concepts related to your question.", + 'technical': "Ontology matching: Selected {classes} classes, {properties} properties from {ontologies}" + }, + 'query_generation': { + 'basic': "I generated a query to search for the information.", + 'detailed': "I created a {query_type} query using {query_language} to search the {backend} database.", + 'technical': "Query generation: {query_language} query with {variables} variables, complexity score {complexity}" + }, + 'query_execution': { + 'basic': "I searched the database and found {result_count} results.", + 'detailed': "I executed the query against the {backend} database and retrieved {result_count} results in {duration}ms.", + 'technical': "Query execution: {backend} backend, {result_count} results, execution time {duration}ms" + }, + 'answer_generation': { + 'basic': "I generated a natural language answer from the results.", + 'detailed': "I processed {result_count} results and generated an answer with {confidence}% confidence.", + 'technical': "Answer generation: {result_count} input results, {generation_method} method, confidence {confidence}" + } + } + + self.reasoning_templates = { + 'entity_identification': "I identified '{entity}' as a key concept 
in your question.", + 'ontology_selection': "I selected the '{ontology}' ontology because it contains relevant information about {concepts}.", + 'query_strategy': "I chose a {strategy} query approach because {reason}.", + 'result_filtering': "I filtered the results to show only the most relevant {count} items.", + 'confidence_assessment': "I'm {confidence}% confident in this answer because {reasoning}." + } + + def explain_query_processing(self, + question: str, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + generated_query: Union[SPARQLQuery, CypherQuery], + query_results: Union[SPARQLResult, CypherResult], + final_answer: str, + processing_metadata: Dict[str, Any]) -> QueryExplanation: + """Generate comprehensive explanation of query processing. + + Args: + question: Original question + question_components: Analyzed question components + ontology_subsets: Selected ontology subsets + generated_query: Generated query + query_results: Query execution results + final_answer: Final generated answer + processing_metadata: Processing metadata + + Returns: + Complete query explanation + """ + query_id = processing_metadata.get('query_id', f"query_{datetime.now().timestamp()}") + start_time = processing_metadata.get('start_time', datetime.now()) + + # Build explanation steps + steps = [] + step_number = 1 + + # Step 1: Question Analysis + steps.append(self._explain_question_analysis( + step_number, question, question_components + )) + step_number += 1 + + # Step 2: Ontology Matching + steps.append(self._explain_ontology_matching( + step_number, question_components, ontology_subsets + )) + step_number += 1 + + # Step 3: Query Generation + steps.append(self._explain_query_generation( + step_number, generated_query, processing_metadata + )) + step_number += 1 + + # Step 4: Query Execution + steps.append(self._explain_query_execution( + step_number, generated_query, query_results, processing_metadata + )) + step_number += 1 + + # Step 5: Answer Generation + steps.append(self._explain_answer_generation( + step_number, query_results, final_answer, processing_metadata + )) + + # Build reasoning chain + reasoning_chain = self._build_reasoning_chain( + question_components, ontology_subsets, generated_query, processing_metadata + ) + + # Calculate overall confidence + confidence_score = self._calculate_explanation_confidence( + question_components, query_results, processing_metadata + ) + + # Generate user-friendly explanation + user_friendly_explanation = self._generate_user_friendly_explanation( + question, question_components, ontology_subsets, final_answer + ) + + # Calculate total duration + total_duration = processing_metadata.get('total_duration_ms', 0) + + return QueryExplanation( + query_id=query_id, + original_question=question, + processing_steps=steps, + final_answer=final_answer, + confidence_score=confidence_score, + total_duration_ms=total_duration, + ontologies_used=[subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets], + backend_used=processing_metadata.get('backend_used', 'unknown'), + reasoning_chain=reasoning_chain, + technical_details=self._extract_technical_details(processing_metadata), + user_friendly_explanation=user_friendly_explanation + ) + + def _explain_question_analysis(self, + step_number: int, + question: str, + question_components: QuestionComponents) -> ExplanationStep: + """Explain question analysis step.""" + template = self.step_templates['question_analysis'][self.explanation_level] + + if 
self.explanation_level == 'basic': + explanation = template + elif self.explanation_level == 'detailed': + explanation = template.format( + question=question, + question_type=question_components.question_type.value.replace('_', ' '), + entities=', '.join(question_components.entities[:3]) + ) + else: # technical + explanation = template.format( + question_type=question_components.question_type.value, + entities=question_components.entities, + keywords=question_components.keywords, + answer_type=question_components.expected_answer_type + ) + + return ExplanationStep( + step_number=step_number, + component="question_analyzer", + operation="analyze_question", + input_data={"question": question}, + output_data={ + "question_type": question_components.question_type.value, + "entities": question_components.entities, + "keywords": question_components.keywords + }, + explanation=explanation, + duration_ms=0.0, # Would be tracked in actual implementation + success=True + ) + + def _explain_ontology_matching(self, + step_number: int, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset]) -> ExplanationStep: + """Explain ontology matching step.""" + template = self.step_templates['ontology_matching'][self.explanation_level] + + total_elements = sum( + len(subset.classes) + len(subset.object_properties) + len(subset.datatype_properties) + for subset in ontology_subsets + ) + + if self.explanation_level == 'basic': + explanation = template.format( + entities=', '.join(question_components.entities[:3]) + ) + elif self.explanation_level == 'detailed': + explanation = template.format( + ontology_count=len(ontology_subsets), + selected_elements=total_elements + ) + else: # technical + total_classes = sum(len(subset.classes) for subset in ontology_subsets) + total_properties = sum( + len(subset.object_properties) + len(subset.datatype_properties) + for subset in ontology_subsets + ) + ontology_names = [subset.metadata.get('ontology_id', 'unknown') for subset in ontology_subsets] + + explanation = template.format( + classes=total_classes, + properties=total_properties, + ontologies=', '.join(ontology_names) + ) + + return ExplanationStep( + step_number=step_number, + component="ontology_matcher", + operation="select_relevant_subset", + input_data={"entities": question_components.entities}, + output_data={ + "ontology_count": len(ontology_subsets), + "total_elements": total_elements + }, + explanation=explanation, + duration_ms=0.0, + success=True + ) + + def _explain_query_generation(self, + step_number: int, + generated_query: Union[SPARQLQuery, CypherQuery], + metadata: Dict[str, Any]) -> ExplanationStep: + """Explain query generation step.""" + template = self.step_templates['query_generation'][self.explanation_level] + + query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher" + backend = metadata.get('backend_used', 'unknown') + + if self.explanation_level == 'basic': + explanation = template + elif self.explanation_level == 'detailed': + explanation = template.format( + query_type=generated_query.query_type, + query_language=query_language, + backend=backend + ) + else: # technical + explanation = template.format( + query_language=query_language, + variables=len(generated_query.variables), + complexity=f"{generated_query.complexity_score:.2f}" + ) + + return ExplanationStep( + step_number=step_number, + component="query_generator", + operation="generate_query", + input_data={"query_type": generated_query.query_type}, + output_data={ + 
"query_language": query_language, + "variables": generated_query.variables, + "complexity": generated_query.complexity_score + }, + explanation=explanation, + duration_ms=0.0, + success=True, + metadata={"generated_query": generated_query.query} + ) + + def _explain_query_execution(self, + step_number: int, + generated_query: Union[SPARQLQuery, CypherQuery], + query_results: Union[SPARQLResult, CypherResult], + metadata: Dict[str, Any]) -> ExplanationStep: + """Explain query execution step.""" + template = self.step_templates['query_execution'][self.explanation_level] + + backend = metadata.get('backend_used', 'unknown') + duration = getattr(query_results, 'execution_time', 0) * 1000 # Convert to ms + + if isinstance(query_results, SPARQLResult): + result_count = len(query_results.bindings) + else: # CypherResult + result_count = len(query_results.records) + + if self.explanation_level == 'basic': + explanation = template.format(result_count=result_count) + elif self.explanation_level == 'detailed': + explanation = template.format( + backend=backend, + result_count=result_count, + duration=f"{duration:.1f}" + ) + else: # technical + explanation = template.format( + backend=backend, + result_count=result_count, + duration=f"{duration:.1f}" + ) + + return ExplanationStep( + step_number=step_number, + component="query_executor", + operation="execute_query", + input_data={"query": generated_query.query}, + output_data={ + "result_count": result_count, + "execution_time_ms": duration + }, + explanation=explanation, + duration_ms=duration, + success=result_count >= 0 + ) + + def _explain_answer_generation(self, + step_number: int, + query_results: Union[SPARQLResult, CypherResult], + final_answer: str, + metadata: Dict[str, Any]) -> ExplanationStep: + """Explain answer generation step.""" + template = self.step_templates['answer_generation'][self.explanation_level] + + if isinstance(query_results, SPARQLResult): + result_count = len(query_results.bindings) + else: # CypherResult + result_count = len(query_results.records) + + confidence = metadata.get('answer_confidence', 0.8) * 100 # Convert to percentage + + if self.explanation_level == 'basic': + explanation = template + elif self.explanation_level == 'detailed': + explanation = template.format( + result_count=result_count, + confidence=f"{confidence:.0f}" + ) + else: # technical + generation_method = metadata.get('generation_method', 'template_based') + explanation = template.format( + result_count=result_count, + generation_method=generation_method, + confidence=f"{confidence:.1f}" + ) + + return ExplanationStep( + step_number=step_number, + component="answer_generator", + operation="generate_answer", + input_data={"result_count": result_count}, + output_data={ + "answer": final_answer, + "confidence": confidence / 100 + }, + explanation=explanation, + duration_ms=0.0, + success=bool(final_answer) + ) + + def _build_reasoning_chain(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + generated_query: Union[SPARQLQuery, CypherQuery], + metadata: Dict[str, Any]) -> List[str]: + """Build reasoning chain explaining the decision process.""" + reasoning = [] + + # Entity identification reasoning + if question_components.entities: + for entity in question_components.entities[:3]: + reasoning.append( + self.reasoning_templates['entity_identification'].format(entity=entity) + ) + + # Ontology selection reasoning + if ontology_subsets: + primary_ontology = ontology_subsets[0] + ontology_id = 
primary_ontology.metadata.get('ontology_id', 'primary') + concepts = list(primary_ontology.classes.keys())[:3] + reasoning.append( + self.reasoning_templates['ontology_selection'].format( + ontology=ontology_id, + concepts=', '.join(concepts) + ) + ) + + # Query strategy reasoning + query_language = "SPARQL" if isinstance(generated_query, SPARQLQuery) else "Cypher" + if question_components.question_type == QuestionType.AGGREGATION: + strategy = "aggregation" + reason = "you asked for a count or sum" + elif question_components.question_type == QuestionType.BOOLEAN: + strategy = "boolean" + reason = "you asked a yes/no question" + else: + strategy = "retrieval" + reason = "you asked for specific information" + + reasoning.append( + self.reasoning_templates['query_strategy'].format( + strategy=strategy, + reason=reason + ) + ) + + # Confidence assessment + confidence = metadata.get('answer_confidence', 0.8) * 100 + if confidence > 90: + confidence_reason = "the query matched well with available data" + elif confidence > 70: + confidence_reason = "the query found relevant information with some uncertainty" + else: + confidence_reason = "the available data partially matches your question" + + reasoning.append( + self.reasoning_templates['confidence_assessment'].format( + confidence=f"{confidence:.0f}", + reasoning=confidence_reason + ) + ) + + return reasoning[:self.max_reasoning_steps] + + def _calculate_explanation_confidence(self, + question_components: QuestionComponents, + query_results: Union[SPARQLResult, CypherResult], + metadata: Dict[str, Any]) -> float: + """Calculate confidence score for the explanation.""" + confidence = 0.8 # Base confidence + + # Adjust based on result count + if isinstance(query_results, SPARQLResult): + result_count = len(query_results.bindings) + else: + result_count = len(query_results.records) + + if result_count > 0: + confidence += 0.1 + if result_count > 5: + confidence += 0.05 + + # Adjust based on question complexity + if len(question_components.entities) > 0: + confidence += 0.05 + + # Adjust based on processing success + if metadata.get('all_steps_successful', True): + confidence += 0.05 + + return min(confidence, 1.0) + + def _generate_user_friendly_explanation(self, + question: str, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + final_answer: str) -> str: + """Generate user-friendly explanation of the process.""" + explanation_parts = [] + + # Introduction + explanation_parts.append(f"To answer your question '{question}', I followed these steps:") + + # Process summary + if question_components.question_type == QuestionType.AGGREGATION: + explanation_parts.append("1. I recognized this as a counting or aggregation question") + elif question_components.question_type == QuestionType.BOOLEAN: + explanation_parts.append("1. I recognized this as a yes/no question") + else: + explanation_parts.append("1. I analyzed your question to understand what information you need") + + # Ontology usage + if ontology_subsets: + ontology_count = len(ontology_subsets) + if ontology_count == 1: + explanation_parts.append("2. I searched through the relevant knowledge base") + else: + explanation_parts.append(f"2. I searched through {ontology_count} knowledge bases") + + # Result processing + explanation_parts.append("3. 
I found the relevant information and generated your answer") + + # Conclusion + explanation_parts.append(f"The answer is: {final_answer}") + + return " ".join(explanation_parts) + + def _extract_technical_details(self, metadata: Dict[str, Any]) -> Dict[str, Any]: + """Extract technical details for debugging and optimization.""" + return { + 'query_optimization': metadata.get('query_optimization', {}), + 'backend_performance': metadata.get('backend_performance', {}), + 'cache_usage': metadata.get('cache_usage', {}), + 'error_handling': metadata.get('error_handling', {}), + 'routing_decision': metadata.get('routing_decision', {}) + } + + def format_explanation_for_display(self, + explanation: QueryExplanation, + format_type: str = 'html') -> str: + """Format explanation for display. + + Args: + explanation: Query explanation + format_type: Output format ('html', 'markdown', 'text') + + Returns: + Formatted explanation + """ + if format_type == 'html': + return self._format_html_explanation(explanation) + elif format_type == 'markdown': + return self._format_markdown_explanation(explanation) + else: + return self._format_text_explanation(explanation) + + def _format_html_explanation(self, explanation: QueryExplanation) -> str: + """Format explanation as HTML.""" + html_parts = [ + f"

<h2>Query Explanation: {explanation.query_id}</h2>",
+            f"<p><strong>Question:</strong> {explanation.original_question}</p>",
+            f"<p><strong>Answer:</strong> {explanation.final_answer}</p>",
+            f"<p><strong>Confidence:</strong> {explanation.confidence_score:.1%}</p>",
+            "<h3>Processing Steps:</h3>",
+            "<ol>"
+        ]
+
+        for step in explanation.processing_steps:
+            html_parts.append(f"<li>{step.component}: {step.explanation}</li>")
+
+        html_parts.extend([
+            "</ol>",
+            "<h3>Reasoning:</h3>
", + "") + + return "".join(html_parts) + + def _format_markdown_explanation(self, explanation: QueryExplanation) -> str: + """Format explanation as Markdown.""" + md_parts = [ + f"## Query Explanation: {explanation.query_id}", + f"**Question:** {explanation.original_question}", + f"**Answer:** {explanation.final_answer}", + f"**Confidence:** {explanation.confidence_score:.1%}", + "", + "### Processing Steps:", + "" + ] + + for i, step in enumerate(explanation.processing_steps, 1): + md_parts.append(f"{i}. **{step.component}**: {step.explanation}") + + md_parts.extend([ + "", + "### Reasoning:", + "" + ]) + + for reasoning in explanation.reasoning_chain: + md_parts.append(f"- {reasoning}") + + return "\n".join(md_parts) + + def _format_text_explanation(self, explanation: QueryExplanation) -> str: + """Format explanation as plain text.""" + text_parts = [ + f"Query Explanation: {explanation.query_id}", + f"Question: {explanation.original_question}", + f"Answer: {explanation.final_answer}", + f"Confidence: {explanation.confidence_score:.1%}", + "", + "Processing Steps:", + ] + + for i, step in enumerate(explanation.processing_steps, 1): + text_parts.append(f" {i}. {step.component}: {step.explanation}") + + text_parts.extend([ + "", + "Reasoning:", + ]) + + for reasoning in explanation.reasoning_chain: + text_parts.append(f" - {reasoning}") + + return "\n".join(text_parts) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/query_optimizer.py b/trustgraph-flow/trustgraph/query/ontology/query_optimizer.py new file mode 100644 index 00000000..5d8f36ec --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/query_optimizer.py @@ -0,0 +1,519 @@ +""" +Query optimization module for OntoRAG. +Optimizes SPARQL and Cypher queries for better performance and accuracy. +""" + +import logging +from typing import Dict, Any, List, Optional, Union, Tuple +from dataclasses import dataclass +from enum import Enum +import re + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset +from .sparql_generator import SPARQLQuery +from .cypher_generator import CypherQuery + +logger = logging.getLogger(__name__) + + +class OptimizationStrategy(Enum): + """Query optimization strategies.""" + PERFORMANCE = "performance" + ACCURACY = "accuracy" + BALANCED = "balanced" + + +@dataclass +class OptimizationHint: + """Optimization hint for query processing.""" + strategy: OptimizationStrategy + max_results: Optional[int] = None + timeout_seconds: Optional[int] = None + use_indices: bool = True + enable_parallel: bool = False + cache_results: bool = True + + +@dataclass +class QueryPlan: + """Query execution plan with optimization metadata.""" + original_query: str + optimized_query: str + estimated_cost: float + optimization_notes: List[str] + index_hints: List[str] + execution_order: List[str] + + +class QueryOptimizer: + """Optimizes SPARQL and Cypher queries for performance and accuracy.""" + + def __init__(self, config: Dict[str, Any] = None): + """Initialize query optimizer. 
+ + Args: + config: Optimizer configuration + """ + self.config = config or {} + self.default_strategy = OptimizationStrategy( + self.config.get('default_strategy', 'balanced') + ) + self.max_query_complexity = self.config.get('max_query_complexity', 10) + self.enable_query_rewriting = self.config.get('enable_query_rewriting', True) + + # Performance thresholds + self.large_result_threshold = self.config.get('large_result_threshold', 1000) + self.complex_join_threshold = self.config.get('complex_join_threshold', 3) + + def optimize_sparql(self, + sparql_query: SPARQLQuery, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + optimization_hint: Optional[OptimizationHint] = None) -> Tuple[SPARQLQuery, QueryPlan]: + """Optimize SPARQL query. + + Args: + sparql_query: Original SPARQL query + question_components: Question analysis + ontology_subset: Ontology subset + optimization_hint: Optimization hints + + Returns: + Optimized SPARQL query and execution plan + """ + hint = optimization_hint or OptimizationHint(strategy=self.default_strategy) + + optimized_query = sparql_query.query + optimization_notes = [] + index_hints = [] + execution_order = [] + + # Apply optimizations based on strategy + if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]: + optimized_query, perf_notes, perf_hints = self._optimize_sparql_performance( + optimized_query, question_components, ontology_subset, hint + ) + optimization_notes.extend(perf_notes) + index_hints.extend(perf_hints) + + if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]: + optimized_query, acc_notes = self._optimize_sparql_accuracy( + optimized_query, question_components, ontology_subset + ) + optimization_notes.extend(acc_notes) + + # Estimate query cost + estimated_cost = self._estimate_sparql_cost(optimized_query, ontology_subset) + + # Build execution plan + query_plan = QueryPlan( + original_query=sparql_query.query, + optimized_query=optimized_query, + estimated_cost=estimated_cost, + optimization_notes=optimization_notes, + index_hints=index_hints, + execution_order=execution_order + ) + + # Create optimized query object + optimized_sparql = SPARQLQuery( + query=optimized_query, + variables=sparql_query.variables, + query_type=sparql_query.query_type, + explanation=f"Optimized: {sparql_query.explanation}", + complexity_score=min(sparql_query.complexity_score * 0.8, 1.0) # Assume optimization reduces complexity + ) + + return optimized_sparql, query_plan + + def optimize_cypher(self, + cypher_query: CypherQuery, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + optimization_hint: Optional[OptimizationHint] = None) -> Tuple[CypherQuery, QueryPlan]: + """Optimize Cypher query. 
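+
+        Illustrative call (a sketch; the hint values are assumptions)::
+
+            hint = OptimizationHint(strategy=OptimizationStrategy.PERFORMANCE,
+                                    max_results=100)
+            optimized, plan = optimizer.optimize_cypher(
+                cypher_query, components, subset, optimization_hint=hint)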
+ + Args: + cypher_query: Original Cypher query + question_components: Question analysis + ontology_subset: Ontology subset + optimization_hint: Optimization hints + + Returns: + Optimized Cypher query and execution plan + """ + hint = optimization_hint or OptimizationHint(strategy=self.default_strategy) + + optimized_query = cypher_query.query + optimization_notes = [] + index_hints = [] + execution_order = [] + + # Apply optimizations based on strategy + if hint.strategy in [OptimizationStrategy.PERFORMANCE, OptimizationStrategy.BALANCED]: + optimized_query, perf_notes, perf_hints = self._optimize_cypher_performance( + optimized_query, question_components, ontology_subset, hint + ) + optimization_notes.extend(perf_notes) + index_hints.extend(perf_hints) + + if hint.strategy in [OptimizationStrategy.ACCURACY, OptimizationStrategy.BALANCED]: + optimized_query, acc_notes = self._optimize_cypher_accuracy( + optimized_query, question_components, ontology_subset + ) + optimization_notes.extend(acc_notes) + + # Estimate query cost + estimated_cost = self._estimate_cypher_cost(optimized_query, ontology_subset) + + # Build execution plan + query_plan = QueryPlan( + original_query=cypher_query.query, + optimized_query=optimized_query, + estimated_cost=estimated_cost, + optimization_notes=optimization_notes, + index_hints=index_hints, + execution_order=execution_order + ) + + # Create optimized query object + optimized_cypher = CypherQuery( + query=optimized_query, + variables=cypher_query.variables, + query_type=cypher_query.query_type, + explanation=f"Optimized: {cypher_query.explanation}", + complexity_score=min(cypher_query.complexity_score * 0.8, 1.0) + ) + + return optimized_cypher, query_plan + + def _optimize_sparql_performance(self, + query: str, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + hint: OptimizationHint) -> Tuple[str, List[str], List[str]]: + """Apply performance optimizations to SPARQL query. 
+ + Args: + query: SPARQL query string + question_components: Question analysis + ontology_subset: Ontology subset + hint: Optimization hints + + Returns: + Optimized query, optimization notes, and index hints + """ + optimized = query + notes = [] + index_hints = [] + + # Add LIMIT if not present and large results expected + if hint.max_results and 'LIMIT' not in optimized.upper(): + optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}" + notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets") + + # Optimize OPTIONAL clauses (move to end) + optional_pattern = re.compile(r'OPTIONAL\s*\{[^}]+\}', re.IGNORECASE | re.DOTALL) + optionals = optional_pattern.findall(optimized) + if optionals: + # Remove optionals from current position + for optional in optionals: + optimized = optimized.replace(optional, '') + + # Add them at the end (before ORDER BY/LIMIT) + insert_point = optimized.rfind('ORDER BY') + if insert_point == -1: + insert_point = optimized.rfind('LIMIT') + if insert_point == -1: + insert_point = len(optimized.rstrip()) + + for optional in optionals: + optimized = optimized[:insert_point] + f"\n {optional}" + optimized[insert_point:] + + notes.append("Moved OPTIONAL clauses to end for better performance") + + # Add index hints for Cassandra + if 'WHERE' in optimized.upper(): + # Suggest indices for common patterns + if '?subject rdf:type' in optimized: + index_hints.append("type_index") + if 'rdfs:subClassOf' in optimized: + index_hints.append("hierarchy_index") + + # Optimize FILTER clauses (move closer to variable bindings) + filter_pattern = re.compile(r'FILTER\s*\([^)]+\)', re.IGNORECASE) + filters = filter_pattern.findall(optimized) + if filters: + notes.append("FILTER clauses present - ensure they're positioned optimally") + + return optimized, notes, index_hints + + def _optimize_sparql_accuracy(self, + query: str, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]: + """Apply accuracy optimizations to SPARQL query. + + Args: + query: SPARQL query string + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Optimized query and optimization notes + """ + optimized = query + notes = [] + + # Add missing namespace checks + if question_components.question_type == QuestionType.RETRIEVAL: + # Ensure we're not mixing namespaces inappropriately + if 'http://' in optimized and '?' in optimized: + notes.append("Verified namespace consistency for accuracy") + + # Add type constraints for better precision + if '?entity' in optimized and 'rdf:type' not in optimized: + # Find a good insertion point + where_clause = re.search(r'WHERE\s*\{(.+)\}', optimized, re.DOTALL | re.IGNORECASE) + if where_clause and ontology_subset.classes: + # Add type constraint for the most relevant class + main_class = list(ontology_subset.classes.keys())[0] + type_constraint = f"\n ?entity rdf:type :{main_class} ." 
+ + # Insert after the WHERE { + where_start = where_clause.start(1) + optimized = optimized[:where_start] + type_constraint + optimized[where_start:] + notes.append(f"Added type constraint for {main_class} to improve accuracy") + + # Add DISTINCT if not present for retrieval queries + if (question_components.question_type == QuestionType.RETRIEVAL and + 'DISTINCT' not in optimized.upper() and + 'SELECT' in optimized.upper()): + optimized = optimized.replace('SELECT ', 'SELECT DISTINCT ', 1) + notes.append("Added DISTINCT to eliminate duplicate results") + + return optimized, notes + + def _optimize_cypher_performance(self, + query: str, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset, + hint: OptimizationHint) -> Tuple[str, List[str], List[str]]: + """Apply performance optimizations to Cypher query. + + Args: + query: Cypher query string + question_components: Question analysis + ontology_subset: Ontology subset + hint: Optimization hints + + Returns: + Optimized query, optimization notes, and index hints + """ + optimized = query + notes = [] + index_hints = [] + + # Add LIMIT if not present + if hint.max_results and 'LIMIT' not in optimized.upper(): + optimized = f"{optimized.rstrip()}\nLIMIT {hint.max_results}" + notes.append(f"Added LIMIT {hint.max_results} to prevent large result sets") + + # Use parameters for literals to enable query plan caching + if "'" in optimized or '"' in optimized: + notes.append("Consider using parameters for literal values to enable query plan caching") + + # Suggest indices based on query patterns + if 'MATCH (n:' in optimized: + label_match = re.search(r'MATCH \(n:(\w+)\)', optimized) + if label_match: + label = label_match.group(1) + index_hints.append(f"node_label_index:{label}") + + if 'WHERE' in optimized.upper() and '.' in optimized: + # Property access patterns + property_pattern = re.compile(r'\.(\w+)', re.IGNORECASE) + properties = property_pattern.findall(optimized) + for prop in set(properties): + index_hints.append(f"property_index:{prop}") + + # Optimize relationship traversals + if '-[' in optimized and '*' in optimized: + notes.append("Variable length path detected - consider adding relationship type filters") + + # Early filtering optimization + if 'WHERE' in optimized.upper(): + # Move WHERE clauses closer to MATCH clauses + notes.append("WHERE clauses present - ensure early filtering for performance") + + return optimized, notes, index_hints + + def _optimize_cypher_accuracy(self, + query: str, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Tuple[str, List[str]]: + """Apply accuracy optimizations to Cypher query. + + Args: + query: Cypher query string + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Optimized query and optimization notes + """ + optimized = query + notes = [] + + # Add DISTINCT if not present for retrieval queries + if (question_components.question_type == QuestionType.RETRIEVAL and + 'DISTINCT' not in optimized.upper() and + 'RETURN' in optimized.upper()): + optimized = re.sub(r'RETURN\s+', 'RETURN DISTINCT ', optimized, count=1, flags=re.IGNORECASE) + notes.append("Added DISTINCT to eliminate duplicate results") + + # Ensure proper relationship direction + if '-[' in optimized and question_components.relationships: + notes.append("Verified relationship directions for semantic accuracy") + + # Add null checks for optional properties + if '?' 
in optimized or 'OPTIONAL' in optimized.upper(): + notes.append("Consider adding null checks for optional properties") + + return optimized, notes + + def _estimate_sparql_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float: + """Estimate execution cost for SPARQL query. + + Args: + query: SPARQL query string + ontology_subset: Ontology subset + + Returns: + Estimated cost (0.0 to 1.0) + """ + cost = 0.0 + + # Basic query complexity + cost += len(query.split('\n')) * 0.01 + + # Join complexity + triple_patterns = len(re.findall(r'\?\w+\s+\?\w+\s+\?\w+', query)) + cost += triple_patterns * 0.1 + + # OPTIONAL clauses + optional_count = len(re.findall(r'OPTIONAL', query, re.IGNORECASE)) + cost += optional_count * 0.15 + + # FILTER clauses + filter_count = len(re.findall(r'FILTER', query, re.IGNORECASE)) + cost += filter_count * 0.1 + + # Property paths + path_count = len(re.findall(r'\*|\+', query)) + cost += path_count * 0.2 + + # Ontology subset size impact + total_elements = (len(ontology_subset.classes) + + len(ontology_subset.object_properties) + + len(ontology_subset.datatype_properties)) + cost += (total_elements / 100.0) * 0.1 + + return min(cost, 1.0) + + def _estimate_cypher_cost(self, query: str, ontology_subset: QueryOntologySubset) -> float: + """Estimate execution cost for Cypher query. + + Args: + query: Cypher query string + ontology_subset: Ontology subset + + Returns: + Estimated cost (0.0 to 1.0) + """ + cost = 0.0 + + # Basic query complexity + cost += len(query.split('\n')) * 0.01 + + # Pattern complexity + match_count = len(re.findall(r'MATCH', query, re.IGNORECASE)) + cost += match_count * 0.1 + + # Relationship traversals + rel_count = len(re.findall(r'-\[.*?\]-', query)) + cost += rel_count * 0.1 + + # Variable length paths + var_path_count = len(re.findall(r'\*\d*\.\.', query)) + cost += var_path_count * 0.3 + + # WHERE clauses + where_count = len(re.findall(r'WHERE', query, re.IGNORECASE)) + cost += where_count * 0.05 + + # Aggregation functions + agg_count = len(re.findall(r'COUNT|SUM|AVG|MIN|MAX', query, re.IGNORECASE)) + cost += agg_count * 0.1 + + # Ontology subset size impact + total_elements = (len(ontology_subset.classes) + + len(ontology_subset.object_properties) + + len(ontology_subset.datatype_properties)) + cost += (total_elements / 100.0) * 0.1 + + return min(cost, 1.0) + + def should_use_cache(self, + query: str, + question_components: QuestionComponents, + optimization_hint: OptimizationHint) -> bool: + """Determine if query results should be cached. + + Args: + query: Query string + question_components: Question analysis + optimization_hint: Optimization hints + + Returns: + True if results should be cached + """ + if not optimization_hint.cache_results: + return False + + # Cache simple retrieval and factual queries + if question_components.question_type in [QuestionType.RETRIEVAL, QuestionType.FACTUAL]: + return True + + # Cache expensive aggregation queries + if (question_components.question_type == QuestionType.AGGREGATION and + ('COUNT' in query.upper() or 'SUM' in query.upper())): + return True + + # Don't cache real-time or time-sensitive queries + if any(keyword in question_components.original_question.lower() + for keyword in ['now', 'current', 'latest', 'recent']): + return False + + return False + + def get_cache_key(self, + query: str, + ontology_subset: QueryOntologySubset) -> str: + """Generate cache key for query. 
+ + Args: + query: Query string + ontology_subset: Ontology subset + + Returns: + Cache key string + """ + import hashlib + + # Create stable representation + ontology_repr = f"{sorted(ontology_subset.classes.keys())}-{sorted(ontology_subset.object_properties.keys())}" + combined = f"{query.strip()}-{ontology_repr}" + + return hashlib.md5(combined.encode()).hexdigest() \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/query_service.py b/trustgraph-flow/trustgraph/query/ontology/query_service.py new file mode 100644 index 00000000..ec7884ed --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/query_service.py @@ -0,0 +1,438 @@ +""" +Main OntoRAG query service. +Orchestrates question analysis, ontology matching, query generation, execution, and answer generation. +""" + +import logging +from typing import Dict, Any, List, Optional, Union +from dataclasses import dataclass +from datetime import datetime + +from ....flow.flow_processor import FlowProcessor +from ....tables.config import ConfigTableStore +from ...extract.kg.ontology.ontology_loader import OntologyLoader +from ...extract.kg.ontology.vector_store import InMemoryVectorStore + +from .question_analyzer import QuestionAnalyzer, QuestionComponents +from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .backend_router import BackendRouter, QueryRoute, BackendType +from .sparql_generator import SPARQLGenerator, SPARQLQuery +from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult +from .cypher_generator import CypherGenerator, CypherQuery +from .cypher_executor import CypherExecutor, CypherResult +from .answer_generator import AnswerGenerator, GeneratedAnswer + +logger = logging.getLogger(__name__) + + +@dataclass +class QueryRequest: + """Query request from user.""" + question: str + context: Optional[str] = None + ontology_hint: Optional[str] = None + max_results: int = 10 + confidence_threshold: float = 0.7 + + +@dataclass +class QueryResponse: + """Complete query response.""" + answer: str + confidence: float + execution_time: float + question_analysis: QuestionComponents + ontology_subsets: List[QueryOntologySubset] + query_route: QueryRoute + generated_query: Union[SPARQLQuery, CypherQuery] + raw_results: Union[SPARQLResult, CypherResult] + supporting_facts: List[str] + metadata: Dict[str, Any] + + +class OntoRAGQueryService(FlowProcessor): + """Main OntoRAG query service orchestrating all components.""" + + def __init__(self, config: Dict[str, Any]): + """Initialize OntoRAG query service. 
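A small sketch of the caching helpers above. The questions, query string and the `SimpleNamespace` stand-in for `QueryOntologySubset` are illustrative assumptions; the expected outputs in the comments follow directly from the logic shown in `should_use_cache` and the question-type patterns in `question_analyzer.py`.

```python
from types import SimpleNamespace

from trustgraph.query.ontology.query_optimizer import (
    QueryOptimizer, OptimizationHint, OptimizationStrategy,
)
from trustgraph.query.ontology.question_analyzer import QuestionAnalyzer

optimizer = QueryOptimizer()
analyzer = QuestionAnalyzer()
hint = OptimizationHint(strategy=OptimizationStrategy.BALANCED, cache_results=True)

query = "SELECT ?s WHERE { ?s rdf:type :species }"

# Retrieval-style questions are considered cacheable ...
listing = analyzer.analyze("List all species in the reserve")
print(optimizer.should_use_cache(query, listing, hint))        # True

# ... while time-sensitive wording ('latest', 'current', ...) disables caching.
comparison = analyzer.analyze("Compare the latest observations of wolves and bears")
print(optimizer.should_use_cache(query, comparison, hint))     # False

# Cache keys hash the query together with the subset's class and property identifiers,
# so the same question over the same ontology slice hits the same cache entry.
subset = SimpleNamespace(classes={"animal": {}}, object_properties={"has-parent": {}})
print(optimizer.get_cache_key(query, subset))                  # stable md5 hex digest
```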
+ + Args: + config: Service configuration + """ + super().__init__(config) + self.config = config + + # Initialize components + self.config_store = None + self.ontology_loader = None + self.vector_store = None + self.question_analyzer = None + self.ontology_matcher = None + self.backend_router = None + self.sparql_generator = None + self.sparql_engine = None + self.cypher_generator = None + self.cypher_executor = None + self.answer_generator = None + + # Cache for loaded ontologies + self.ontology_cache = {} + + async def init(self): + """Initialize all components.""" + await super().init() + + # Initialize configuration store + self.config_store = ConfigTableStore(self.config.get('config_store', {})) + + # Initialize ontology components + self.ontology_loader = OntologyLoader(self.config_store) + + # Initialize vector store + vector_config = self.config.get('vector_store', {}) + self.vector_store = InMemoryVectorStore.create( + store_type=vector_config.get('type', 'numpy'), + dimension=vector_config.get('dimension', 384), + similarity_threshold=vector_config.get('similarity_threshold', 0.7) + ) + + # Initialize question analyzer + analyzer_config = self.config.get('question_analyzer', {}) + self.question_analyzer = QuestionAnalyzer( + prompt_service=self.prompt_service, + config=analyzer_config + ) + + # Initialize ontology matcher + matcher_config = self.config.get('ontology_matcher', {}) + self.ontology_matcher = OntologyMatcher( + vector_store=self.vector_store, + embedding_service=self.embedding_service, + config=matcher_config + ) + + # Initialize backend router + router_config = self.config.get('backend_router', {}) + self.backend_router = BackendRouter(router_config) + + # Initialize query generators + self.sparql_generator = SPARQLGenerator(prompt_service=self.prompt_service) + self.cypher_generator = CypherGenerator(prompt_service=self.prompt_service) + + # Initialize executors + sparql_config = self.config.get('sparql_executor', {}) + if self.backend_router.is_backend_enabled(BackendType.CASSANDRA): + cassandra_config = self.backend_router.get_backend_config(BackendType.CASSANDRA) + if cassandra_config: + self.sparql_engine = SPARQLCassandraEngine(cassandra_config) + await self.sparql_engine.initialize() + + cypher_config = self.config.get('cypher_executor', {}) + enabled_graph_backends = [ + bt for bt in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB] + if self.backend_router.is_backend_enabled(bt) + ] + if enabled_graph_backends: + self.cypher_executor = CypherExecutor(cypher_config) + await self.cypher_executor.initialize() + + # Initialize answer generator + self.answer_generator = AnswerGenerator(prompt_service=self.prompt_service) + + logger.info("OntoRAG query service initialized") + + async def process(self, request: QueryRequest) -> QueryResponse: + """Process a natural language query. 
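The configuration keys below mirror the ones `init()` reads. Every value is a placeholder assumption, the nested structures for the stores and router are defined in modules outside this diff, and the snippet takes the surrounding `FlowProcessor` wiring (including the `prompt_service` and `embedding_service` attributes the service relies on) as given.

```python
from trustgraph.query.ontology.query_service import OntoRAGQueryService

# Hypothetical configuration; key names follow init() above, values are placeholders.
service_config = {
    "config_store": {},        # structure defined by ConfigTableStore (not in this diff)
    "vector_store": {"type": "numpy", "dimension": 384, "similarity_threshold": 0.7},
    "question_analyzer": {},
    "ontology_matcher": {},
    "backend_router": {},      # structure defined by BackendRouter (not in this diff)
    "sparql_executor": {},
    "cypher_executor": {},
}


async def build_service() -> OntoRAGQueryService:
    """Construct and wire the service, assuming its components initialise as intended."""
    service = OntoRAGQueryService(service_config)
    await service.init()   # wires analyzer, matcher, router, generators and executors
    return service
```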
+ + Args: + request: Query request + + Returns: + Complete query response + """ + start_time = datetime.now() + + try: + logger.info(f"Processing query: {request.question}") + + # Step 1: Analyze question + question_components = await self.question_analyzer.analyze_question( + request.question, context=request.context + ) + logger.debug(f"Question analysis: {question_components.question_type}") + + # Step 2: Load and match ontologies + ontology_subsets = await self._load_and_match_ontologies( + question_components, request.ontology_hint + ) + logger.debug(f"Found {len(ontology_subsets)} relevant ontology subsets") + + # Step 3: Route to appropriate backend + query_route = self.backend_router.route_query( + question_components, ontology_subsets + ) + logger.debug(f"Routed to {query_route.backend_type.value} backend") + + # Step 4: Generate and execute query + if query_route.query_language == 'sparql': + query_results = await self._execute_sparql_path( + question_components, ontology_subsets, query_route + ) + else: # cypher + query_results = await self._execute_cypher_path( + question_components, ontology_subsets, query_route + ) + + # Step 5: Generate natural language answer + generated_answer = await self.answer_generator.generate_answer( + question_components, + query_results['raw_results'], + ontology_subsets[0] if ontology_subsets else None, + query_route.backend_type.value + ) + + # Build response + execution_time = (datetime.now() - start_time).total_seconds() + + response = QueryResponse( + answer=generated_answer.answer, + confidence=min(query_route.confidence, generated_answer.metadata.confidence), + execution_time=execution_time, + question_analysis=question_components, + ontology_subsets=ontology_subsets, + query_route=query_route, + generated_query=query_results['generated_query'], + raw_results=query_results['raw_results'], + supporting_facts=generated_answer.supporting_facts, + metadata={ + 'backend_used': query_route.backend_type.value, + 'query_language': query_route.query_language, + 'ontology_count': len(ontology_subsets), + 'result_count': generated_answer.metadata.result_count, + 'routing_reasoning': query_route.reasoning, + 'generation_time': generated_answer.generation_time + } + ) + + logger.info(f"Query processed successfully in {execution_time:.2f}s") + return response + + except Exception as e: + logger.error(f"Query processing failed: {e}") + execution_time = (datetime.now() - start_time).total_seconds() + + # Return error response + return QueryResponse( + answer=f"I encountered an error processing your query: {str(e)}", + confidence=0.0, + execution_time=execution_time, + question_analysis=QuestionComponents( + original_question=request.question, + normalized_question=request.question, + question_type=None, + entities=[], keywords=[], relationships=[], constraints=[], + aggregations=[], expected_answer_type="unknown" + ), + ontology_subsets=[], + query_route=None, + generated_query=None, + raw_results=None, + supporting_facts=[], + metadata={'error': str(e), 'execution_time': execution_time} + ) + + async def _load_and_match_ontologies(self, + question_components: QuestionComponents, + ontology_hint: Optional[str] = None) -> List[QueryOntologySubset]: + """Load ontologies and find relevant subsets. 
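A sketch of one round trip through `process()`, reusing the hypothetical `build_service()` helper from the previous sketch. The question is made up; the fields read from the response come from the `QueryResponse` dataclass above.

```python
import asyncio

from trustgraph.query.ontology.query_service import QueryRequest


async def ask(question: str) -> None:
    service = await build_service()       # hypothetical helper from the sketch above
    try:
        request = QueryRequest(question=question, max_results=10)
        response = await service.process(request)

        print(response.answer)
        print(f"confidence: {response.confidence:.2f}")
        print(f"backend:    {response.metadata.get('backend_used')}")
        print(f"took:       {response.execution_time:.2f}s")
        for fact in response.supporting_facts:
            print(f" - {fact}")
    finally:
        await service.close()


# asyncio.run(ask("How many animals are there?"))
```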
+ + Args: + question_components: Analyzed question + ontology_hint: Optional ontology hint + + Returns: + List of relevant ontology subsets + """ + try: + # Load available ontologies + if ontology_hint: + # Load specific ontology + ontologies = [await self.ontology_loader.load_ontology(ontology_hint)] + else: + # Load all available ontologies + available_ontologies = await self.ontology_loader.list_available_ontologies() + ontologies = [] + for ontology_id in available_ontologies[:5]: # Limit to 5 for performance + try: + ontology = await self.ontology_loader.load_ontology(ontology_id) + ontologies.append(ontology) + except Exception as e: + logger.warning(f"Failed to load ontology {ontology_id}: {e}") + + if not ontologies: + logger.warning("No ontologies loaded") + return [] + + # Extract relevant subsets + ontology_subsets = [] + for ontology in ontologies: + subset = await self.ontology_matcher.select_relevant_subset( + question_components, ontology + ) + if subset and (subset.classes or subset.object_properties or subset.datatype_properties): + ontology_subsets.append(subset) + + return ontology_subsets + + except Exception as e: + logger.error(f"Failed to load and match ontologies: {e}") + return [] + + async def _execute_sparql_path(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + query_route: QueryRoute) -> Dict[str, Any]: + """Execute SPARQL query path. + + Args: + question_components: Question analysis + ontology_subsets: Ontology subsets + query_route: Query route + + Returns: + Query execution results + """ + if not self.sparql_engine: + raise RuntimeError("SPARQL engine not initialized") + + # Generate SPARQL query + primary_subset = ontology_subsets[0] if ontology_subsets else None + sparql_query = await self.sparql_generator.generate_sparql( + question_components, primary_subset + ) + + logger.debug(f"Generated SPARQL: {sparql_query.query}") + + # Execute query + sparql_results = self.sparql_engine.execute_sparql(sparql_query.query) + + return { + 'generated_query': sparql_query, + 'raw_results': sparql_results + } + + async def _execute_cypher_path(self, + question_components: QuestionComponents, + ontology_subsets: List[QueryOntologySubset], + query_route: QueryRoute) -> Dict[str, Any]: + """Execute Cypher query path. + + Args: + question_components: Question analysis + ontology_subsets: Ontology subsets + query_route: Query route + + Returns: + Query execution results + """ + if not self.cypher_executor: + raise RuntimeError("Cypher executor not initialized") + + # Generate Cypher query + primary_subset = ontology_subsets[0] if ontology_subsets else None + cypher_query = await self.cypher_generator.generate_cypher( + question_components, primary_subset + ) + + logger.debug(f"Generated Cypher: {cypher_query.query}") + + # Execute query + database_type = query_route.backend_type.value + cypher_results = await self.cypher_executor.execute_query( + cypher_query.query, database_type=database_type + ) + + return { + 'generated_query': cypher_query, + 'raw_results': cypher_results + } + + async def get_supported_backends(self) -> List[str]: + """Get list of supported and enabled backends. + + Returns: + List of backend names + """ + return [bt.value for bt in self.backend_router.get_available_backends()] + + async def get_available_ontologies(self) -> List[str]: + """Get list of available ontologies. 
+ + Returns: + List of ontology identifiers + """ + if self.ontology_loader: + return await self.ontology_loader.list_available_ontologies() + return [] + + async def health_check(self) -> Dict[str, Any]: + """Perform health check on all components. + + Returns: + Health status of all components + """ + health = { + 'service': 'healthy', + 'components': {}, + 'backends': {}, + 'ontologies': {} + } + + try: + # Check ontology loader + if self.ontology_loader: + ontologies = await self.ontology_loader.list_available_ontologies() + health['components']['ontology_loader'] = 'healthy' + health['ontologies']['count'] = len(ontologies) + else: + health['components']['ontology_loader'] = 'not_initialized' + + # Check vector store + if self.vector_store: + health['components']['vector_store'] = 'healthy' + health['components']['vector_store_type'] = type(self.vector_store).__name__ + else: + health['components']['vector_store'] = 'not_initialized' + + # Check backends + for backend_type in self.backend_router.get_available_backends(): + if backend_type == BackendType.CASSANDRA and self.sparql_engine: + health['backends']['cassandra'] = 'healthy' + elif backend_type in [BackendType.NEO4J, BackendType.MEMGRAPH, BackendType.FALKORDB] and self.cypher_executor: + health['backends'][backend_type.value] = 'healthy' + else: + health['backends'][backend_type.value] = 'configured_but_not_initialized' + + except Exception as e: + health['service'] = 'degraded' + health['error'] = str(e) + + return health + + async def close(self): + """Close all connections and cleanup resources.""" + try: + if self.sparql_engine: + self.sparql_engine.close() + + if self.cypher_executor: + await self.cypher_executor.close() + + if self.config_store: + # ConfigTableStore cleanup if needed + pass + + logger.info("OntoRAG query service closed") + + except Exception as e: + logger.error(f"Error closing OntoRAG query service: {e}") \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py b/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py new file mode 100644 index 00000000..3e48ac78 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/question_analyzer.py @@ -0,0 +1,364 @@ +""" +Question analyzer for ontology-sensitive query system. +Decomposes user questions into semantic components. +""" + +import logging +import re +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass +from enum import Enum + +logger = logging.getLogger(__name__) + + +class QuestionType(Enum): + """Types of questions that can be asked.""" + FACTUAL = "factual" # What is X? + RETRIEVAL = "retrieval" # Find all X + AGGREGATION = "aggregation" # How many X? + COMPARISON = "comparison" # Is X better than Y? + RELATIONSHIP = "relationship" # How is X related to Y? + BOOLEAN = "boolean" # Yes/no questions + PROCESS = "process" # How to do X? + TEMPORAL = "temporal" # When did X happen? + SPATIAL = "spatial" # Where is X? 
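The `health_check()` and `close()` methods above make the service easy to probe from an operations script; a brief sketch (the `service` argument is an already-initialised `OntoRAGQueryService`):

```python
import json


async def report_health(service) -> None:
    # health_check() returns a plain dict with 'service', 'components',
    # 'backends' and 'ontologies' sections.
    health = await service.health_check()
    print(json.dumps(health, indent=2))
    if health["service"] != "healthy":
        print("degraded:", health.get("error"))
```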
+ + +@dataclass +class QuestionComponents: + """Components extracted from a question.""" + original_question: str + question_type: QuestionType + entities: List[str] + relationships: List[str] + constraints: List[str] + aggregations: List[str] + expected_answer_type: str + keywords: List[str] + + +class QuestionAnalyzer: + """Analyzes natural language questions to extract semantic components.""" + + def __init__(self): + """Initialize question analyzer.""" + # Question word patterns + self.question_patterns = { + QuestionType.FACTUAL: [ + r'^what\s+(?:is|are)', + r'^who\s+(?:is|are)', + r'^which\s+', + ], + QuestionType.RETRIEVAL: [ + r'^find\s+', + r'^list\s+', + r'^show\s+', + r'^get\s+', + r'^retrieve\s+', + ], + QuestionType.AGGREGATION: [ + r'^how\s+many', + r'^count\s+', + r'^what\s+(?:is|are)\s+the\s+(?:number|total|sum)', + ], + QuestionType.COMPARISON: [ + r'(?:better|worse|more|less|greater|smaller)\s+than', + r'compare\s+', + r'difference\s+between', + ], + QuestionType.RELATIONSHIP: [ + r'^how\s+(?:is|are).*related', + r'relationship\s+between', + r'connection\s+between', + ], + QuestionType.BOOLEAN: [ + r'^(?:is|are|was|were|do|does|did|can|could|will|would|should)', + r'^has\s+', + r'^have\s+', + ], + QuestionType.PROCESS: [ + r'^how\s+(?:to|do)', + r'^explain\s+how', + ], + QuestionType.TEMPORAL: [ + r'^when\s+', + r'what\s+time', + r'what\s+date', + ], + QuestionType.SPATIAL: [ + r'^where\s+', + r'location\s+of', + ], + } + + # Aggregation keywords + self.aggregation_keywords = [ + 'count', 'sum', 'total', 'average', 'mean', 'median', + 'maximum', 'minimum', 'max', 'min', 'number of' + ] + + # Constraint patterns + self.constraint_patterns = [ + r'(?:with|having|where)\s+(.+?)(?:\s+and|\s+or|$)', + r'(?:greater|less|more|fewer)\s+than\s+(\d+)', + r'(?:between|from)\s+(.+?)\s+(?:and|to)\s+(.+)', + r'(?:before|after|since|until)\s+(.+)', + ] + + def analyze(self, question: str) -> QuestionComponents: + """Analyze a question to extract components. + + Args: + question: Natural language question + + Returns: + QuestionComponents with extracted information + """ + # Normalize question + question_lower = question.lower().strip() + + # Determine question type + question_type = self._identify_question_type(question_lower) + + # Extract entities + entities = self._extract_entities(question) + + # Extract relationships + relationships = self._extract_relationships(question_lower) + + # Extract constraints + constraints = self._extract_constraints(question_lower) + + # Extract aggregations + aggregations = self._extract_aggregations(question_lower) + + # Determine expected answer type + answer_type = self._determine_answer_type(question_type, aggregations) + + # Extract keywords + keywords = self._extract_keywords(question_lower) + + return QuestionComponents( + original_question=question, + question_type=question_type, + entities=entities, + relationships=relationships, + constraints=constraints, + aggregations=aggregations, + expected_answer_type=answer_type, + keywords=keywords + ) + + def _identify_question_type(self, question: str) -> QuestionType: + """Identify the type of question. + + Args: + question: Lowercase question text + + Returns: + QuestionType enum value + """ + for q_type, patterns in self.question_patterns.items(): + for pattern in patterns: + if re.search(pattern, question): + return q_type + + # Default to factual + return QuestionType.FACTUAL + + def _extract_entities(self, question: str) -> List[str]: + """Extract potential entities from question. 
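A couple of worked calls to `analyze()`. The questions are invented, and the comments only assert what follows directly from the regex patterns and answer-type mapping above.

```python
from trustgraph.query.ontology.question_analyzer import QuestionAnalyzer, QuestionType

analyzer = QuestionAnalyzer()

counting = analyzer.analyze("Count all animals with wings")
assert counting.question_type == QuestionType.AGGREGATION   # matches r'^count\s+'
assert "count" in counting.aggregations
assert counting.expected_answer_type == "number"

yes_no = analyzer.analyze("Is a wolf a mammal?")
assert yes_no.question_type == QuestionType.BOOLEAN          # leading auxiliary verb
assert yes_no.expected_answer_type == "boolean"

print(counting.keywords)   # stop words and short tokens removed, e.g. 'animals', 'wings'
```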
+ + Args: + question: Original question text + + Returns: + List of entity strings + """ + entities = [] + + # Extract capitalized words/phrases (potential proper nouns) + # Pattern for consecutive capitalized words + pattern = r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b' + matches = re.findall(pattern, question) + entities.extend(matches) + + # Extract quoted strings + quoted = re.findall(r'"([^"]+)"', question) + entities.extend(quoted) + quoted = re.findall(r"'([^']+)'", question) + entities.extend(quoted) + + # Remove duplicates while preserving order + seen = set() + unique_entities = [] + for entity in entities: + if entity not in seen: + seen.add(entity) + unique_entities.append(entity) + + return unique_entities + + def _extract_relationships(self, question: str) -> List[str]: + """Extract relationship indicators from question. + + Args: + question: Lowercase question text + + Returns: + List of relationship strings + """ + relationships = [] + + # Common relationship patterns + rel_patterns = [ + r'(\w+)\s+(?:of|by|from|to|with|for)\s+', + r'has\s+(\w+)', + r'belongs?\s+to', + r'(?:created|written|authored|owned)\s+by', + r'related\s+to', + r'connected\s+to', + r'associated\s+with', + ] + + for pattern in rel_patterns: + matches = re.findall(pattern, question) + relationships.extend(matches) + + # Clean up + relationships = [r for r in relationships if len(r) > 2] + return list(set(relationships)) + + def _extract_constraints(self, question: str) -> List[str]: + """Extract constraints from question. + + Args: + question: Lowercase question text + + Returns: + List of constraint strings + """ + constraints = [] + + for pattern in self.constraint_patterns: + matches = re.findall(pattern, question) + if matches: + if isinstance(matches[0], tuple): + constraints.extend(list(matches[0])) + else: + constraints.extend(matches) + + # Clean up + constraints = [c.strip() for c in constraints if c and len(c.strip()) > 0] + return constraints + + def _extract_aggregations(self, question: str) -> List[str]: + """Extract aggregation operations from question. + + Args: + question: Lowercase question text + + Returns: + List of aggregation operations + """ + aggregations = [] + + for keyword in self.aggregation_keywords: + if keyword in question: + aggregations.append(keyword) + + return aggregations + + def _determine_answer_type(self, question_type: QuestionType, + aggregations: List[str]) -> str: + """Determine expected answer type. + + Args: + question_type: Type of question + aggregations: Aggregation operations found + + Returns: + Expected answer type string + """ + if aggregations: + if any(a in ['count', 'number of', 'total'] for a in aggregations): + return 'number' + elif any(a in ['average', 'mean', 'median'] for a in aggregations): + return 'number' + elif any(a in ['sum'] for a in aggregations): + return 'number' + + if question_type == QuestionType.BOOLEAN: + return 'boolean' + elif question_type == QuestionType.TEMPORAL: + return 'datetime' + elif question_type == QuestionType.SPATIAL: + return 'location' + elif question_type == QuestionType.RETRIEVAL: + return 'list' + elif question_type == QuestionType.COMPARISON: + return 'comparison' + else: + return 'text' + + def _extract_keywords(self, question: str) -> List[str]: + """Extract important keywords from question. 
+ + Args: + question: Lowercase question text + + Returns: + List of keywords + """ + # Remove common stop words + stop_words = { + 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', + 'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', + 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', + 'does', 'did', 'will', 'would', 'could', 'should', 'may', + 'might', 'must', 'can', 'shall', 'what', 'which', 'who', + 'when', 'where', 'why', 'how' + } + + # Extract words + words = re.findall(r'\b\w+\b', question) + + # Filter stop words and short words + keywords = [w for w in words if w not in stop_words and len(w) > 2] + + # Remove duplicates while preserving order + seen = set() + unique_keywords = [] + for kw in keywords: + if kw not in seen: + seen.add(kw) + unique_keywords.append(kw) + + return unique_keywords + + def get_question_segments(self, question: str) -> List[str]: + """Split question into segments for embedding. + + Args: + question: Question text + + Returns: + List of question segments + """ + segments = [] + + # Add full question + segments.append(question) + + # Split by clauses + clauses = re.split(r'[,;]', question) + segments.extend([c.strip() for c in clauses if len(c.strip()) > 3]) + + # Extract key phrases + components = self.analyze(question) + segments.extend(components.entities) + segments.extend(components.keywords) + + # Remove duplicates + return list(dict.fromkeys(segments)) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py new file mode 100644 index 00000000..688e7371 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py @@ -0,0 +1,481 @@ +""" +SPARQL-Cassandra engine using Python rdflib. +Executes SPARQL queries against Cassandra using a custom Store implementation. +""" + +import logging +from typing import Dict, Any, List, Optional, Iterator, Tuple +from dataclasses import dataclass +import json + +# Try to import rdflib +try: + from rdflib import Graph, Namespace, URIRef, Literal, BNode + from rdflib.store import Store + from rdflib.plugins.sparql.processor import SPARQLResult + from rdflib.plugins.sparql import prepareQuery + from rdflib.term import Node + RDFLIB_AVAILABLE = True +except ImportError: + RDFLIB_AVAILABLE = False + +# Try to import Cassandra driver +try: + from cassandra.cluster import Cluster + from cassandra.auth import PlainTextAuthProvider + from cassandra.policies import DCAwareRoundRobinPolicy + CASSANDRA_AVAILABLE = True +except ImportError: + CASSANDRA_AVAILABLE = False + +from ....tables.config import ConfigTableStore + +logger = logging.getLogger(__name__) + + +@dataclass +class SPARQLResult: + """Result from SPARQL query execution.""" + bindings: List[Dict[str, Any]] + variables: List[str] + ask_result: Optional[bool] = None # For ASK queries + execution_time: float = 0.0 + query_plan: Optional[str] = None + + +class CassandraTripleStore(Store if RDFLIB_AVAILABLE else object): + """Custom rdflib Store implementation for Cassandra.""" + + def __init__(self, cassandra_config: Dict[str, Any]): + """Initialize Cassandra triple store. 
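The segmentation helper above produces the strings that get embedded for ontology matching; a quick sketch with an assumed sample question:

```python
from trustgraph.query.ontology.question_analyzer import QuestionAnalyzer

analyzer = QuestionAnalyzer()
segments = analyzer.get_question_segments("Find reptiles in Kenya, sorted by size")

# Deduplicated mix of the full question, its comma-separated clauses,
# detected entities (capitalised or quoted spans) and the remaining keywords.
for segment in segments:
    print(segment)
```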
+ + Args: + cassandra_config: Cassandra connection configuration + """ + if not CASSANDRA_AVAILABLE: + raise RuntimeError("Cassandra driver not available") + if not RDFLIB_AVAILABLE: + raise RuntimeError("rdflib not available") + + super().__init__() + + self.cassandra_config = cassandra_config + self.cluster = None + self.session = None + self.keyspace = cassandra_config.get('keyspace', 'trustgraph') + + # Triple storage table structure + self.triple_table = f"{self.keyspace}.triples" + self.metadata_table = f"{self.keyspace}.triple_metadata" + + def open(self, configuration=None, create=False): + """Open connection to Cassandra.""" + try: + # Create authentication if provided + auth_provider = None + if 'username' in self.cassandra_config and 'password' in self.cassandra_config: + auth_provider = PlainTextAuthProvider( + username=self.cassandra_config['username'], + password=self.cassandra_config['password'] + ) + + # Create cluster + self.cluster = Cluster( + [self.cassandra_config.get('host', 'localhost')], + port=self.cassandra_config.get('port', 9042), + auth_provider=auth_provider, + load_balancing_policy=DCAwareRoundRobinPolicy() + ) + + # Connect + self.session = self.cluster.connect() + + # Ensure keyspace exists + if create: + self._create_schema() + + # Set keyspace + self.session.set_keyspace(self.keyspace) + + logger.info(f"Connected to Cassandra cluster: {self.cassandra_config.get('host')}") + return True + + except Exception as e: + logger.error(f"Failed to connect to Cassandra: {e}") + return False + + def close(self, commit_pending_transaction=True): + """Close Cassandra connection.""" + if self.session: + self.session.shutdown() + if self.cluster: + self.cluster.shutdown() + + def _create_schema(self): + """Create Cassandra schema for triple storage.""" + # Create keyspace + self.session.execute(f""" + CREATE KEYSPACE IF NOT EXISTS {self.keyspace} + WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}} + """) + + # Create triples table optimized for SPARQL queries + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.triple_table} ( + subject text, + predicate text, + object text, + object_datatype text, + object_language text, + is_literal boolean, + graph_id text, + PRIMARY KEY ((subject), predicate, object) + ) + """) + + # Create indexes for efficient querying + self.session.execute(f""" + CREATE INDEX IF NOT EXISTS ON {self.triple_table} (predicate) + """) + self.session.execute(f""" + CREATE INDEX IF NOT EXISTS ON {self.triple_table} (object) + """) + + # Metadata table for graph information + self.session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.metadata_table} ( + graph_id text PRIMARY KEY, + created timestamp, + modified timestamp, + triple_count counter + ) + """) + + def triples(self, triple_pattern, context=None): + """Retrieve triples matching the given pattern. 
+ + Args: + triple_pattern: (subject, predicate, object) pattern with None for variables + context: Graph context (optional) + + Yields: + Matching triples as (subject, predicate, object) tuples + """ + if not self.session: + return + + subject, predicate, object_val = triple_pattern + + # Build CQL query based on pattern + cql_queries = self._pattern_to_cql(subject, predicate, object_val) + + for cql, params in cql_queries: + try: + rows = self.session.execute(cql, params) + for row in rows: + yield self._row_to_triple(row) + except Exception as e: + logger.error(f"Error executing CQL query: {e}") + + def _pattern_to_cql(self, subject, predicate, object_val) -> List[Tuple[str, List]]: + """Convert triple pattern to CQL queries. + + Args: + subject: Subject node or None + predicate: Predicate node or None + object_val: Object node or None + + Returns: + List of (CQL query, parameters) tuples + """ + queries = [] + + # Convert None to wildcard, nodes to strings + s_str = str(subject) if subject else None + p_str = str(predicate) if predicate else None + o_str = str(object_val) if object_val else None + + if s_str and p_str and o_str: + # Specific triple lookup + cql = f"SELECT * FROM {self.triple_table} WHERE subject = ? AND predicate = ? AND object = ?" + queries.append((cql, [s_str, p_str, o_str])) + + elif s_str and p_str: + # Subject and predicate known + cql = f"SELECT * FROM {self.triple_table} WHERE subject = ? AND predicate = ?" + queries.append((cql, [s_str, p_str])) + + elif s_str: + # Subject known + cql = f"SELECT * FROM {self.triple_table} WHERE subject = ?" + queries.append((cql, [s_str])) + + elif p_str: + # Predicate known (requires index scan) + cql = f"SELECT * FROM {self.triple_table} WHERE predicate = ? ALLOW FILTERING" + queries.append((cql, [p_str])) + + elif o_str: + # Object known (requires index scan) + cql = f"SELECT * FROM {self.triple_table} WHERE object = ? ALLOW FILTERING" + queries.append((cql, [o_str])) + + else: + # Full scan (should be avoided in production) + cql = f"SELECT * FROM {self.triple_table}" + queries.append((cql, [])) + + return queries + + def _row_to_triple(self, row): + """Convert Cassandra row to RDF triple. + + Args: + row: Cassandra row object + + Returns: + (subject, predicate, object) tuple with rdflib nodes + """ + # Convert to rdflib nodes + subject = URIRef(row.subject) if row.subject.startswith('http') else BNode(row.subject) + + predicate = URIRef(row.predicate) + + if row.is_literal: + # Create literal with datatype/language + if row.object_datatype: + object_node = Literal(row.object, datatype=URIRef(row.object_datatype)) + elif row.object_language: + object_node = Literal(row.object, lang=row.object_language) + else: + object_node = Literal(row.object) + else: + object_node = URIRef(row.object) if row.object.startswith('http') else BNode(row.object) + + return (subject, predicate, object_node) + + def add(self, triple, context=None, quoted=False): + """Add a triple to the store. 
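To make the pattern-to-CQL mapping above concrete, a small sketch follows. It only builds statements, so no live cluster is needed, although the Cassandra driver and rdflib must be importable; the keyspace name and URIs are assumptions, and `_pattern_to_cql` is a private helper used here purely for illustration.

```python
from rdflib import URIRef

from trustgraph.query.ontology.sparql_cassandra import CassandraTripleStore

store = CassandraTripleStore({"keyspace": "trustgraph"})

wolf = URIRef("http://trustgraph.ai/ontologies/natural-world#wolf")
rdf_type = URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")

# Subject and predicate bound, object free: partition key plus clustering column lookup.
for cql, params in store._pattern_to_cql(wolf, rdf_type, None):
    print(cql)     # SELECT * FROM trustgraph.triples WHERE subject = ? AND predicate = ?
    print(params)

# Only the object bound: falls back to a secondary-index scan with ALLOW FILTERING.
for cql, params in store._pattern_to_cql(None, None, URIRef("http://example.org/mammal")):
    print(cql)
```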
+ + Args: + triple: (subject, predicate, object) tuple + context: Graph context + quoted: Whether triple is quoted + """ + if not self.session: + return + + subject, predicate, object_val = triple + + # Convert to storage format + s_str = str(subject) + p_str = str(predicate) + + is_literal = isinstance(object_val, Literal) + o_str = str(object_val) + o_datatype = str(object_val.datatype) if is_literal and object_val.datatype else None + o_language = object_val.language if is_literal and object_val.language else None + + # Insert into Cassandra + cql = f""" + INSERT INTO {self.triple_table} + (subject, predicate, object, object_datatype, object_language, is_literal, graph_id) + VALUES (?, ?, ?, ?, ?, ?, ?) + """ + + try: + self.session.execute(cql, [ + s_str, p_str, o_str, o_datatype, o_language, is_literal, + str(context) if context else 'default' + ]) + except Exception as e: + logger.error(f"Error adding triple: {e}") + + def remove(self, triple, context=None): + """Remove a triple from the store. + + Args: + triple: (subject, predicate, object) tuple + context: Graph context + """ + if not self.session: + return + + subject, predicate, object_val = triple + + cql = f""" + DELETE FROM {self.triple_table} + WHERE subject = ? AND predicate = ? AND object = ? + """ + + try: + self.session.execute(cql, [str(subject), str(predicate), str(object_val)]) + except Exception as e: + logger.error(f"Error removing triple: {e}") + + def __len__(self, context=None): + """Get number of triples in store. + + Args: + context: Graph context + + Returns: + Number of triples + """ + if not self.session: + return 0 + + try: + cql = f"SELECT COUNT(*) FROM {self.triple_table}" + result = self.session.execute(cql) + return result.one().count + except Exception as e: + logger.error(f"Error counting triples: {e}") + return 0 + + +class SPARQLCassandraEngine: + """SPARQL processor using Cassandra backend.""" + + def __init__(self, cassandra_config: Dict[str, Any]): + """Initialize SPARQL-Cassandra engine. + + Args: + cassandra_config: Cassandra configuration + """ + if not RDFLIB_AVAILABLE: + raise RuntimeError("rdflib is required for SPARQL processing") + if not CASSANDRA_AVAILABLE: + raise RuntimeError("Cassandra driver is required") + + self.cassandra_config = cassandra_config + self.store = CassandraTripleStore(cassandra_config) + self.graph = Graph(store=self.store) + + # Common namespaces + self.namespaces = { + 'rdf': Namespace('http://www.w3.org/1999/02/22-rdf-syntax-ns#'), + 'rdfs': Namespace('http://www.w3.org/2000/01/rdf-schema#'), + 'owl': Namespace('http://www.w3.org/2002/07/owl#'), + 'xsd': Namespace('http://www.w3.org/2001/XMLSchema#'), + } + + # Bind namespaces to graph + for prefix, namespace in self.namespaces.items(): + self.graph.bind(prefix, namespace) + + async def initialize(self, create_schema=False): + """Initialize the engine. + + Args: + create_schema: Whether to create Cassandra schema + """ + success = self.store.open(create=create_schema) + if not success: + raise RuntimeError("Failed to connect to Cassandra") + + logger.info("SPARQL-Cassandra engine initialized") + + def execute_sparql(self, sparql_query: str) -> SPARQLResult: + """Execute SPARQL query against Cassandra. 
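A sketch of writing through the store directly, following the `open`/`add`/`__len__`/`close` methods above. It assumes a reachable Cassandra node; the host, port and keyspace are illustrative values.

```python
from rdflib import Literal, URIRef

from trustgraph.query.ontology.sparql_cassandra import CassandraTripleStore

store = CassandraTripleStore({
    "host": "localhost",        # assumed connection details
    "port": 9042,
    "keyspace": "trustgraph",
})

if store.open(create=True):                       # creates keyspace and tables on first use
    wolf = URIRef("http://trustgraph.ai/ontologies/natural-world#wolf")
    label = URIRef("http://www.w3.org/2000/01/rdf-schema#label")

    store.add((wolf, label, Literal("Wolf", lang="en")))
    print(len(store), "triples stored")           # COUNT(*) over the triples table
    store.close()
```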
+ + Args: + sparql_query: SPARQL query string + + Returns: + Query results + """ + import time + start_time = time.time() + + try: + # Prepare and execute query + prepared_query = prepareQuery(sparql_query) + result = self.graph.query(prepared_query) + + execution_time = time.time() - start_time + + # Format results based on query type + if sparql_query.strip().upper().startswith('ASK'): + return SPARQLResult( + bindings=[], + variables=[], + ask_result=bool(result), + execution_time=execution_time + ) + else: + # SELECT query + bindings = [] + variables = result.vars if hasattr(result, 'vars') else [] + + for row in result: + binding = {} + for i, var in enumerate(variables): + if i < len(row): + value = row[i] + binding[str(var)] = self._format_result_value(value) + bindings.append(binding) + + return SPARQLResult( + bindings=bindings, + variables=[str(v) for v in variables], + execution_time=execution_time + ) + + except Exception as e: + logger.error(f"SPARQL execution error: {e}") + return SPARQLResult( + bindings=[], + variables=[], + execution_time=time.time() - start_time + ) + + def _format_result_value(self, value): + """Format result value for output. + + Args: + value: RDF value (URIRef, Literal, BNode) + + Returns: + Formatted value + """ + if isinstance(value, URIRef): + return {'type': 'uri', 'value': str(value)} + elif isinstance(value, Literal): + result = {'type': 'literal', 'value': str(value)} + if value.datatype: + result['datatype'] = str(value.datatype) + if value.language: + result['language'] = value.language + return result + elif isinstance(value, BNode): + return {'type': 'bnode', 'value': str(value)} + else: + return {'type': 'unknown', 'value': str(value)} + + def load_triples_from_store(self, config_store: ConfigTableStore): + """Load triples from TrustGraph's storage into the RDF graph. + + Args: + config_store: Configuration store with triples + """ + # This would need to be implemented based on how triples are stored + # in TrustGraph's Cassandra tables + logger.info("Loading triples from TrustGraph store...") + + # Example implementation - would need to be adapted + # to actual TrustGraph storage format + try: + # Get all triple data + # This is a placeholder - actual implementation would need + # to query the appropriate TrustGraph tables + pass + + except Exception as e: + logger.error(f"Error loading triples: {e}") + + def close(self): + """Close the engine and connections.""" + if self.store: + self.store.close() + logger.info("SPARQL-Cassandra engine closed") \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py b/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py new file mode 100644 index 00000000..44c7e0a1 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_generator.py @@ -0,0 +1,487 @@ +""" +SPARQL query generator for ontology-sensitive queries. +Converts natural language questions to SPARQL queries for Cassandra execution. 
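Tying the engine together, the sketch below initialises it and runs a SELECT. The connection details are assumptions and a running Cassandra instance with loaded triples is required; result values come back in the dict shape produced by `_format_result_value` above.

```python
import asyncio

from trustgraph.query.ontology.sparql_cassandra import SPARQLCassandraEngine


async def run_query() -> None:
    engine = SPARQLCassandraEngine({"host": "localhost", "keyspace": "trustgraph"})
    await engine.initialize(create_schema=False)

    result = engine.execute_sparql("""
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        SELECT ?s ?label WHERE {
            ?s rdfs:label ?label .
        } LIMIT 5
    """)

    print(f"{len(result.bindings)} rows in {result.execution_time:.3f}s")
    for binding in result.bindings:
        # Each value is a dict such as {'type': 'literal', 'value': 'Wolf', 'language': 'en'}
        print(binding.get("s"), binding.get("label"))

    engine.close()


# asyncio.run(run_query())
```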
+""" + +import logging +from typing import Dict, Any, List, Optional +from dataclasses import dataclass + +from .question_analyzer import QuestionComponents, QuestionType +from .ontology_matcher import QueryOntologySubset + +logger = logging.getLogger(__name__) + + +@dataclass +class SPARQLQuery: + """Generated SPARQL query with metadata.""" + query: str + variables: List[str] + query_type: str # SELECT, ASK, CONSTRUCT, DESCRIBE + explanation: str + complexity_score: float + + +class SPARQLGenerator: + """Generates SPARQL queries from natural language questions using LLM assistance.""" + + def __init__(self, prompt_service=None): + """Initialize SPARQL generator. + + Args: + prompt_service: Service for LLM-based query generation + """ + self.prompt_service = prompt_service + + # SPARQL query templates for common patterns + self.templates = { + 'simple_class_query': """ +PREFIX : <{namespace}> +PREFIX rdf: +PREFIX rdfs: + +SELECT ?entity ?label WHERE {{ + ?entity rdf:type :{class_name} . + OPTIONAL {{ ?entity rdfs:label ?label }} +}}""", + + 'property_query': """ +PREFIX : <{namespace}> +PREFIX rdf: +PREFIX rdfs: + +SELECT ?subject ?object WHERE {{ + ?subject :{property} ?object . + ?subject rdf:type :{subject_class} . +}}""", + + 'hierarchy_query': """ +PREFIX : <{namespace}> +PREFIX rdf: +PREFIX rdfs: + +SELECT ?subclass ?superclass WHERE {{ + ?subclass rdfs:subClassOf* ?superclass . + ?superclass rdf:type :{root_class} . +}}""", + + 'count_query': """ +PREFIX : <{namespace}> +PREFIX rdf: + +SELECT (COUNT(?entity) AS ?count) WHERE {{ + ?entity rdf:type :{class_name} . + {additional_constraints} +}}""", + + 'boolean_query': """ +PREFIX : <{namespace}> +PREFIX rdf: + +ASK {{ + {triple_pattern} +}}""" + } + + async def generate_sparql(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> SPARQLQuery: + """Generate SPARQL query for a question. + + Args: + question_components: Analyzed question components + ontology_subset: Relevant ontology subset + + Returns: + Generated SPARQL query + """ + # Try template-based generation first + template_query = self._try_template_generation(question_components, ontology_subset) + if template_query: + logger.debug("Generated SPARQL using template") + return template_query + + # Fall back to LLM-based generation + if self.prompt_service: + llm_query = await self._generate_with_llm(question_components, ontology_subset) + if llm_query: + logger.debug("Generated SPARQL using LLM") + return llm_query + + # Final fallback to simple pattern + logger.warning("Falling back to simple SPARQL pattern") + return self._generate_fallback_query(question_components, ontology_subset) + + def _try_template_generation(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]: + """Try to generate query using templates. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Generated query or None if no template matches + """ + namespace = ontology_subset.metadata.get('namespace', 'http://example.org/') + + # Simple class query (What are the animals?) 
+ if (question_components.question_type == QuestionType.RETRIEVAL and + len(question_components.entities) == 1 and + question_components.entities[0].lower() in [c.lower() for c in ontology_subset.classes]): + + class_name = self._find_matching_class(question_components.entities[0], ontology_subset) + if class_name: + query = self.templates['simple_class_query'].format( + namespace=namespace, + class_name=class_name + ) + return SPARQLQuery( + query=query, + variables=['entity', 'label'], + query_type='SELECT', + explanation=f"Retrieve all instances of {class_name}", + complexity_score=0.3 + ) + + # Count query (How many animals are there?) + if (question_components.question_type == QuestionType.AGGREGATION and + 'count' in question_components.aggregations and + len(question_components.entities) >= 1): + + class_name = self._find_matching_class(question_components.entities[0], ontology_subset) + if class_name: + query = self.templates['count_query'].format( + namespace=namespace, + class_name=class_name, + additional_constraints=self._build_constraints(question_components, ontology_subset) + ) + return SPARQLQuery( + query=query, + variables=['count'], + query_type='SELECT', + explanation=f"Count instances of {class_name}", + complexity_score=0.4 + ) + + # Boolean query (Is X a Y?) + if question_components.question_type == QuestionType.BOOLEAN: + triple_pattern = self._build_boolean_pattern(question_components, ontology_subset) + if triple_pattern: + query = self.templates['boolean_query'].format( + namespace=namespace, + triple_pattern=triple_pattern + ) + return SPARQLQuery( + query=query, + variables=[], + query_type='ASK', + explanation="Boolean query for fact checking", + complexity_score=0.2 + ) + + return None + + async def _generate_with_llm(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Optional[SPARQLQuery]: + """Generate SPARQL using LLM. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Generated query or None if failed + """ + try: + prompt = self._build_sparql_prompt(question_components, ontology_subset) + response = await self.prompt_service.generate_sparql(prompt=prompt) + + if response and isinstance(response, dict): + query = response.get('query', '').strip() + if query.upper().startswith(('SELECT', 'ASK', 'CONSTRUCT', 'DESCRIBE')): + return SPARQLQuery( + query=query, + variables=self._extract_variables(query), + query_type=query.split()[0].upper(), + explanation=response.get('explanation', 'Generated by LLM'), + complexity_score=self._calculate_complexity(query) + ) + + except Exception as e: + logger.error(f"LLM SPARQL generation failed: {e}") + + return None + + def _build_sparql_prompt(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> str: + """Build prompt for LLM SPARQL generation. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Formatted prompt string + """ + namespace = ontology_subset.metadata.get('namespace', 'http://example.org/') + + # Format ontology elements + classes_str = self._format_classes_for_prompt(ontology_subset.classes, namespace) + props_str = self._format_properties_for_prompt( + ontology_subset.object_properties, + ontology_subset.datatype_properties, + namespace + ) + + prompt = f"""Generate a SPARQL query for the following question using the provided ontology. 
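A sketch of the template path above. `QuestionComponents` is built by hand here so the example stays focused (the entity extractor keys on capitalised or quoted spans); the namespace is an assumption, a `SimpleNamespace` stands in for `QueryOntologySubset`, and the rendered query relies on the templates carrying the standard W3C rdf:/rdfs: prefix IRIs.

```python
import asyncio
from types import SimpleNamespace

from trustgraph.query.ontology.question_analyzer import QuestionComponents, QuestionType
from trustgraph.query.ontology.sparql_generator import SPARQLGenerator

components = QuestionComponents(
    original_question="Find every animal",
    question_type=QuestionType.RETRIEVAL,
    entities=["animal"],
    relationships=[],
    constraints=[],
    aggregations=[],
    expected_answer_type="list",
    keywords=["animal"],
)

# Stand-in for QueryOntologySubset (defined in ontology_matcher.py, not in this diff).
subset = SimpleNamespace(
    metadata={"namespace": "http://trustgraph.ai/ontologies/natural-world#"},
    classes={"animal": {"comment": "An animal"}},
    object_properties={},
    datatype_properties={},
)

generator = SPARQLGenerator(prompt_service=None)   # no LLM: template or fallback only
sparql = asyncio.run(generator.generate_sparql(components, subset))

print(sparql.query_type)     # SELECT, via the simple_class_query template
print(sparql.variables)      # ['entity', 'label']
print(sparql.query)
```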
+ +QUESTION: {question_components.original_question} + +ONTOLOGY NAMESPACE: {namespace} + +AVAILABLE CLASSES: +{classes_str} + +AVAILABLE PROPERTIES: +{props_str} + +RULES: +- Use proper SPARQL syntax +- Include appropriate prefixes +- Use property paths for hierarchical queries (rdfs:subClassOf*) +- Add FILTER clauses for constraints +- Optimize for Cassandra backend +- Return both query and explanation + +QUERY TYPE HINTS: +- Question type: {question_components.question_type.value} +- Expected answer: {question_components.expected_answer_type} +- Entities mentioned: {', '.join(question_components.entities)} +- Aggregations: {', '.join(question_components.aggregations)} + +Generate a complete SPARQL query:""" + + return prompt + + def _generate_fallback_query(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> SPARQLQuery: + """Generate simple fallback query. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + Basic SPARQL query + """ + namespace = ontology_subset.metadata.get('namespace', 'http://example.org/') + + # Very basic SELECT query + query = f"""PREFIX : <{namespace}> +PREFIX rdf: +PREFIX rdfs: + +SELECT ?subject ?predicate ?object WHERE {{ + ?subject ?predicate ?object . + FILTER(CONTAINS(STR(?subject), "{question_components.keywords[0] if question_components.keywords else 'entity'}")) +}} +LIMIT 10""" + + return SPARQLQuery( + query=query, + variables=['subject', 'predicate', 'object'], + query_type='SELECT', + explanation="Fallback query for basic pattern matching", + complexity_score=0.1 + ) + + def _find_matching_class(self, entity: str, ontology_subset: QueryOntologySubset) -> Optional[str]: + """Find matching class in ontology subset. + + Args: + entity: Entity string to match + ontology_subset: Ontology subset + + Returns: + Matching class name or None + """ + entity_lower = entity.lower() + + # Direct match + for class_id in ontology_subset.classes: + if class_id.lower() == entity_lower: + return class_id + + # Label match + for class_id, class_def in ontology_subset.classes.items(): + labels = class_def.get('labels', []) + for label in labels: + if isinstance(label, dict): + label_value = label.get('value', '').lower() + if label_value == entity_lower: + return class_id + + # Partial match + for class_id in ontology_subset.classes: + if entity_lower in class_id.lower() or class_id.lower() in entity_lower: + return class_id + + return None + + def _build_constraints(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> str: + """Build constraint clauses for SPARQL. + + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + SPARQL constraint string + """ + constraints = [] + + for constraint in question_components.constraints: + # Simple constraint patterns + if 'greater than' in constraint.lower(): + # Extract number + import re + numbers = re.findall(r'\d+', constraint) + if numbers: + constraints.append(f"FILTER(?value > {numbers[0]})") + + elif 'less than' in constraint.lower(): + numbers = re.findall(r'\d+', constraint) + if numbers: + constraints.append(f"FILTER(?value < {numbers[0]})") + + return '\n '.join(constraints) + + def _build_boolean_pattern(self, + question_components: QuestionComponents, + ontology_subset: QueryOntologySubset) -> Optional[str]: + """Build triple pattern for boolean queries. 
+ + Args: + question_components: Question analysis + ontology_subset: Ontology subset + + Returns: + SPARQL triple pattern or None + """ + if len(question_components.entities) >= 2: + subject = question_components.entities[0] + object_val = question_components.entities[1] + + # Try to find connecting property + for prop_id in ontology_subset.object_properties: + return f":{subject} :{prop_id} :{object_val} ." + + # Fallback to type check + return f":{subject} rdf:type :{object_val} ." + + return None + + def _format_classes_for_prompt(self, classes: Dict[str, Any], namespace: str) -> str: + """Format classes for prompt. + + Args: + classes: Classes dictionary + namespace: Ontology namespace + + Returns: + Formatted classes string + """ + if not classes: + return "None" + + lines = [] + for class_id, definition in classes.items(): + comment = definition.get('comment', '') + parent = definition.get('subclass_of', 'Thing') + lines.append(f"- :{class_id} (subclass of :{parent}) - {comment}") + + return '\n'.join(lines) + + def _format_properties_for_prompt(self, + object_props: Dict[str, Any], + datatype_props: Dict[str, Any], + namespace: str) -> str: + """Format properties for prompt. + + Args: + object_props: Object properties + datatype_props: Datatype properties + namespace: Ontology namespace + + Returns: + Formatted properties string + """ + lines = [] + + for prop_id, definition in object_props.items(): + domain = definition.get('domain', 'Any') + range_val = definition.get('range', 'Any') + comment = definition.get('comment', '') + lines.append(f"- :{prop_id} (:{domain} -> :{range_val}) - {comment}") + + for prop_id, definition in datatype_props.items(): + domain = definition.get('domain', 'Any') + range_val = definition.get('range', 'xsd:string') + comment = definition.get('comment', '') + lines.append(f"- :{prop_id} (:{domain} -> {range_val}) - {comment}") + + return '\n'.join(lines) if lines else "None" + + def _extract_variables(self, query: str) -> List[str]: + """Extract variables from SPARQL query. + + Args: + query: SPARQL query string + + Returns: + List of variable names + """ + import re + variables = re.findall(r'\?(\w+)', query) + return list(set(variables)) + + def _calculate_complexity(self, query: str) -> float: + """Calculate complexity score for SPARQL query. + + Args: + query: SPARQL query string + + Returns: + Complexity score (0.0 to 1.0) + """ + complexity = 0.0 + + # Count different SPARQL features + query_upper = query.upper() + + if 'JOIN' in query_upper or 'UNION' in query_upper: + complexity += 0.3 + if 'FILTER' in query_upper: + complexity += 0.2 + if 'OPTIONAL' in query_upper: + complexity += 0.1 + if 'GROUP BY' in query_upper: + complexity += 0.2 + if 'ORDER BY' in query_upper: + complexity += 0.1 + if '*' in query: # Property paths + complexity += 0.1 + + # Count variables + variables = self._extract_variables(query) + complexity += len(variables) * 0.05 + + return min(complexity, 1.0) \ No newline at end of file
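As a closing illustration of the scoring helpers above (private methods, called here purely for demonstration; the query text is made up):

```python
from trustgraph.query.ontology.sparql_generator import SPARQLGenerator

generator = SPARQLGenerator()

query = """PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
SELECT ?animal ?label WHERE {
  ?animal a :animal .
  OPTIONAL { ?animal rdfs:label ?label }
  FILTER(LANG(?label) = "en")
}"""

print(generator._extract_variables(query))        # {'animal', 'label'}, in set order
# FILTER contributes 0.2, OPTIONAL 0.1 and each of the 2 variables 0.05,
# so the score comes out at roughly 0.4.
print(generator._calculate_complexity(query))
```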