@@ -430,6 +430,118 @@ def __init__(
430430 )
431431 self ._llm_params : dict [str , Any ] = llm_params or {}
432432
433+ def _filter_invalid_patterns (
434+ self ,
435+ patterns : List [Tuple [str , str , str ]],
436+ node_types : List [Dict [str , Any ]],
437+ relationship_types : Optional [List [Dict [str , Any ]]] = None ,
438+ ) -> List [Tuple [str , str , str ]]:
439+ """
440+ Filter out patterns that reference undefined node types or relationship types.
441+
442+ Args:
443+ patterns: List of patterns to filter.
444+ node_types: List of node type definitions.
445+ relationship_types: Optional list of relationship type definitions.
446+
447+ Returns:
448+ Filtered list of patterns containing only valid references.
449+ """
450+ # Early returns for missing required types
451+ if not node_types :
452+ logging .info (
453+ "Filtering out all patterns because no node types are defined. "
454+ "Patterns reference node types that must be defined."
455+ )
456+ return []
457+
458+ if not relationship_types :
459+ logging .info (
460+ "Filtering out all patterns because no relationship types are defined. "
461+ "GraphSchema validation requires relationship_types when patterns are provided."
462+ )
463+ return []
464+
465+ # Create sets of valid labels
466+ valid_node_labels = {node_type ["label" ] for node_type in node_types }
467+ valid_relationship_labels = {
468+ rel_type ["label" ] for rel_type in relationship_types
469+ }
470+
471+ # Filter patterns
472+ filtered_patterns = []
473+ for pattern in patterns :
474+ if not (isinstance (pattern , (list , tuple )) and len (pattern ) == 3 ):
475+ continue
476+
477+ entity1 , relation , entity2 = pattern
478+
479+ # Check if all components are valid
480+ if (
481+ entity1 in valid_node_labels
482+ and entity2 in valid_node_labels
483+ and relation in valid_relationship_labels
484+ ):
485+ filtered_patterns .append (pattern )
486+ else :
487+ # Log invalid pattern with validation details
488+ entity1_valid = entity1 in valid_node_labels
489+ entity2_valid = entity2 in valid_node_labels
490+ relation_valid = relation in valid_relationship_labels
491+
492+ logging .info (
493+ f"Filtering out invalid pattern: { pattern } . "
494+ f"Entity1 '{ entity1 } ' valid: { entity1_valid } , "
495+ f"Entity2 '{ entity2 } ' valid: { entity2_valid } , "
496+ f"Relation '{ relation } ' valid: { relation_valid } "
497+ )
498+
499+ return filtered_patterns
500+
501+ def _filter_nodes_without_labels (
502+ self , node_types : List [Dict [str , Any ]]
503+ ) -> List [Dict [str , Any ]]:
504+ """
505+ Filter out node types that have no labels.
506+
507+ Args:
508+ node_types: List of node type definitions.
509+
510+ Returns:
511+ Filtered list of node types containing only those with valid labels.
512+ """
513+ filtered_nodes = []
514+ for node_type in node_types :
515+ if node_type .get ("label" ):
516+ filtered_nodes .append (node_type )
517+ else :
518+ logging .info (f"Filtering out node type with missing label: { node_type } " )
519+
520+ return filtered_nodes
521+
522+ def _filter_relationships_without_labels (
523+ self , relationship_types : List [Dict [str , Any ]]
524+ ) -> List [Dict [str , Any ]]:
525+ """
526+ Filter out relationship types that have no labels.
527+
528+ Args:
529+ relationship_types: List of relationship type definitions.
530+
531+ Returns:
532+ Filtered list of relationship types containing only those with valid labels.
533+ """
534+ filtered_relationships = []
535+ for rel_type in relationship_types :
536+ if rel_type .get ("label" ):
537+ filtered_relationships .append (rel_type )
538+ else :
539+ logging .info (
540+ f"Filtering out relationship type with missing label: { rel_type } "
541+ )
542+
543+ return filtered_relationships
544+
433545 @validate_call
434546 async def run (self , text : str , examples : str = "" , ** kwargs : Any ) -> GraphSchema :
435547 """
@@ -459,13 +571,13 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
459571 pass # Keep as is
460572 # handle list
461573 elif isinstance (extracted_schema , list ):
462- if len (extracted_schema ) > 0 and isinstance (extracted_schema [0 ], dict ):
463- extracted_schema = extracted_schema [0 ]
464- elif len (extracted_schema ) == 0 :
465- logging .warning (
574+ if len (extracted_schema ) == 0 :
575+ logging .info (
466576 "LLM returned an empty list for schema. Falling back to empty schema."
467577 )
468578 extracted_schema = {}
579+ elif isinstance (extracted_schema [0 ], dict ):
580+ extracted_schema = extracted_schema [0 ]
469581 else :
470582 raise SchemaExtractionError (
471583 f"Expected a dictionary or list of dictionaries, but got list containing: { type (extracted_schema [0 ])} "
@@ -488,6 +600,19 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
488600 "patterns"
489601 )
490602
603+ # Filter out nodes and relationships without labels
604+ extracted_node_types = self ._filter_nodes_without_labels (extracted_node_types )
605+ if extracted_relationship_types :
606+ extracted_relationship_types = self ._filter_relationships_without_labels (
607+ extracted_relationship_types
608+ )
609+
610+ # Filter out invalid patterns before validation
611+ if extracted_patterns :
612+ extracted_patterns = self ._filter_invalid_patterns (
613+ extracted_patterns , extracted_node_types , extracted_relationship_types
614+ )
615+
491616 return GraphSchema .model_validate (
492617 {
493618 "node_types" : extracted_node_types ,
0 commit comments