diff --git a/.github/workflows/docker-master.yaml b/.github/workflows/docker-master.yaml index 086bb48..90487db 100644 --- a/.github/workflows/docker-master.yaml +++ b/.github/workflows/docker-master.yaml @@ -4,12 +4,13 @@ on: branches: [master] pull_request: branches: [master] + jobs: ci: strategy: fail-fast: false matrix: - python-version: [3.9, "3.10", "3.11"] + python-version: [3.9, "3.10", "3.11", "3.12"] poetry-version: [1.4.2] os: [ubuntu-latest, windows-latest] runs-on: ${{ matrix.os }} diff --git a/pyproject.toml b/pyproject.toml index 90f9c0c..6823c93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "turms" -version = "0.5.0" +version = "0.7.0" description = "graphql-codegen powered by pydantic" authors = ["jhnnsrs "] license = "MIT" diff --git a/tests/test_multi_interface.py b/tests/test_multi_interface.py index 7f69c78..b0404b6 100644 --- a/tests/test_multi_interface.py +++ b/tests/test_multi_interface.py @@ -83,5 +83,5 @@ def test_fragment_generation(multi_interface_schema): unit_test_with( generated_ast, - "assert FlowNodeBaseReactiveNode(id='soinosins', position={'x': 3, 'y': 3}).id, 'Needs to be not nown'", + "assert FlowNodesBaseReactiveNode(id='soinosins', position={'x': 3, 'y': 3}).id, 'Needs to be not nown'", ) diff --git a/turms/config.py b/turms/config.py index a4db2f3..fa47e56 100644 --- a/turms/config.py +++ b/turms/config.py @@ -1,5 +1,13 @@ import builtins -from pydantic import AnyHttpUrl, BaseModel, Field, GetCoreSchemaHandler, field_validator, validator, ConfigDict +from pydantic import ( + AnyHttpUrl, + BaseModel, + Field, + GetCoreSchemaHandler, + field_validator, + validator, + ConfigDict, +) from pydantic_core import core_schema from pydantic_settings import BaseSettings, SettingsConfigDict from typing import ( @@ -22,7 +30,6 @@ class ConfigProxy(BaseModel): type: str - class ImportableFunctionMixin(Protocol): @classmethod @@ -33,7 +40,6 @@ def __get_pydantic_core_schema__( cls.validate, handler(callable), field_name=handler.field_name ) - @classmethod def validate(cls, v, *info): if not callable(v): @@ -57,7 +63,6 @@ def __get_pydantic_core_schema__( cls.validate, handler(str), field_name=handler.field_name ) - @classmethod def validate(cls, v, *info): if not isinstance(v, str): @@ -147,7 +152,7 @@ class OptionsConfig(BaseSettings): enabled: bool = Field(False, description="Enabling this, will freeze the schema") """Enabling this, will freeze the schema""" - extra: ExtraOptions = None + extra: ExtraOptions = None """Extra options for pydantic""" allow_mutation: Optional[bool] = None """Allow mutation""" @@ -179,6 +184,7 @@ class OptionsConfig(BaseSettings): PydanticVersion = Literal["v1", "v2"] + class GeneratorConfig(BaseSettings): """Configuration for the generator @@ -189,11 +195,10 @@ class GeneratorConfig(BaseSettings): and the scalars that should be used. """ + model_config: SettingsConfigDict = SettingsConfigDict( env_prefix="TURMS_", extra="forbid", - - ) pydantic_version: PydanticVersion = "v2" @@ -291,7 +296,6 @@ def validate_importable(cls, v): return v - class Extensions(BaseModel): """Wrapping class to be able to extract the tums configuraiton""" @@ -317,6 +321,7 @@ class GraphQLProject(BaseSettings): Turm will use the schema and documents to generate the python models, according to the generator configuration under extensions.turms """ + model_config: SettingsConfigDict = SettingsConfigDict( env_prefix="TURMS_GRAPHQL_", extra="allow", @@ -335,6 +340,7 @@ class GraphQLConfigMultiple(BaseSettings): This is the main configuration for multiple GraphQL projects. It is compliant with the graphql-config specification for multiple projec.""" + model_config: SettingsConfigDict = SettingsConfigDict( extra="allow", ) @@ -343,13 +349,13 @@ class GraphQLConfigMultiple(BaseSettings): """ The projects that should be parsed. The key is the name of the project and the value is the graphql project""" - class GraphQLConfigSingle(GraphQLProject): """Configuration for a single GraphQL project This is the main configuration for a single GraphQL project. It is compliant with the graphql-config specification for a single project. """ + model_config: SettingsConfigDict = SettingsConfigDict( extra="allow", ) diff --git a/turms/helpers.py b/turms/helpers.py index c983fa0..8861820 100644 --- a/turms/helpers.py +++ b/turms/helpers.py @@ -96,9 +96,11 @@ def load_dsl_from_url(url: AnyHttpUrl, headers: Dict[str, str] = None) -> DSLStr default_headers.update(headers) try: req = requests.get(url, headers=default_headers) - x = req.text() - except Exception: - raise GenerationError(f"Failed to fetch schema from {url}") + assert req.status_code == 200, "Incorrect status code" + assert req.content, "No content" + x = req.content.decode() + except Exception as e: + raise GenerationError(f"Failed to fetch schema from {url}") from e return x diff --git a/turms/parsers/base.py b/turms/parsers/base.py index bd227b3..77a23eb 100644 --- a/turms/parsers/base.py +++ b/turms/parsers/base.py @@ -9,7 +9,6 @@ class ParserConfig(BaseSettings): type: str - class Parser(BaseModel): """Base class for all parsers diff --git a/turms/parsers/polyfill.py b/turms/parsers/polyfill.py index 9d98648..32759c5 100644 --- a/turms/parsers/polyfill.py +++ b/turms/parsers/polyfill.py @@ -22,7 +22,6 @@ def validate_python_version(cls, value): return value - def polyfill_python_seven( asts: List[ast.AST], config: PolyfillPluginConfig ) -> List[ast.AST]: diff --git a/turms/plugins/base.py b/turms/plugins/base.py index 3211b0c..d051ad7 100644 --- a/turms/plugins/base.py +++ b/turms/plugins/base.py @@ -13,7 +13,6 @@ class PluginConfig(BaseSettings): type: str - class Plugin(BaseModel): """ Base class for all plugins @@ -21,6 +20,7 @@ class Plugin(BaseModel): Plugins are the workhorse of turms. They are used to generate python code, according to the GraphQL schema. You can use plugins to generate python code for your GraphQL schema. THe all received the graphql schema and the config of the plugin.""" + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) config: PluginConfig log: LogFunction = Field(default=lambda *args, **kwargs: print(*args)) diff --git a/turms/plugins/dd.py b/turms/plugins/dd.py new file mode 100644 index 0000000..99c6994 --- /dev/null +++ b/turms/plugins/dd.py @@ -0,0 +1,326 @@ +import ast +from typing import List, Optional + +from pydantic_settings import SettingsConfigDict +from turms.config import GeneratorConfig +from graphql.utilities.build_client_schema import GraphQLSchema +from turms.recurse import type_field_node +from turms.plugins.base import Plugin, PluginConfig +from pydantic import Field +from graphql.language.ast import FragmentDefinitionNode +from turms.registry import ClassRegistry +from turms.utils import ( + generate_pydantic_config, + generate_typename_field, + get_additional_bases_for_type, + get_interface_bases, + parse_documents, +) +from graphql import ( + FieldNode, + FragmentSpreadNode, + GraphQLInterfaceType, + GraphQLObjectType, + InlineFragmentNode, + language, +) +from turms.config import GraphQLTypes + +from graphql.utilities.type_info import get_field_def +import logging + + +logger = logging.getLogger(__name__) + + +class FragmentsPluginConfig(PluginConfig): + model_config = SettingsConfigDict(env_prefix="TURMS_PLUGINS_FRAGMENTS_") + type: str = "turms.plugins.fragments.FragmentsPlugin" + fragment_bases: List[str] = [] + fragments_glob: Optional[str] = None + + +def get_fragment_bases( + config: GeneratorConfig, + pluginConfig: FragmentsPluginConfig, + registry: ClassRegistry, +): + if pluginConfig.fragment_bases: + for base in pluginConfig.fragment_bases: + registry.register_import(base) + + return [ + ast.Name(id=base.split(".")[-1], ctx=ast.Load()) + for base in pluginConfig.fragment_bases + ] + + else: + for base in config.object_bases: + registry.register_import(base) + + return [ + ast.Name(id=base.split(".")[-1], ctx=ast.Load()) + for base in config.object_bases + ] + + +def generate_fragment( + f: FragmentDefinitionNode, + client_schema: GraphQLSchema, + config: GeneratorConfig, + plugin_config: FragmentsPluginConfig, + registry: ClassRegistry, +): + tree = [] + fields = [] + type = client_schema.get_type(f.type_condition.name.value) + name = registry.generate_fragment( + f.name.value, isinstance(type, GraphQLInterfaceType) + ) + + registry.register_fragment_document( + f.name.value, language.print_ast(f) + ) # TODO: Check if typename is being referenced? so that we can check between the elements of the interface + + if isinstance(type, GraphQLInterfaceType): + mother_class_fields = [] + base_fragment_name = f"{name}" + additional_bases = get_additional_bases_for_type(type.name, config, registry) + + if type.description: + mother_class_fields.append( + ast.Expr(value=ast.Constant(value=type.description)) + ) + + for sub_node in f.selection_set.selections: + + if isinstance(sub_node, FieldNode): + + if sub_node.name.value == "__typename": + continue + + field_type = type.fields[sub_node.name.value] + mother_class_fields += type_field_node( + sub_node, + base_fragment_name, + field_type, + client_schema, + config, + tree, + registry, + ) + + mother_class = ast.ClassDef( + base_fragment_name, + bases=get_interface_bases(config, registry) + + additional_bases, # Todo: fill with base + decorator_list=[], + keywords=[], + body=mother_class_fields if mother_class_fields else [ast.Pass()], + ) + + tree.append(mother_class) + + union_class_names = [] + + for sub_node in f.selection_set.selections: + + if isinstance(sub_node, FragmentSpreadNode): + # Spread nodes are like inheritance? + spreaded_name = f"{base_fragment_name}{sub_node.name.value}" + + cls = ast.ClassDef( + spreaded_name, + bases=[ + ast.Name( + id=registry.inherit_fragment(sub_node.name.value), + ctx=ast.Load(), + ), + ast.Name(id=base_fragment_name, ctx=ast.Load()), + ], + decorator_list=[], + keywords=[], + body=[ast.Pass()], + ) + + tree.append(cls) + union_class_names.append(spreaded_name) + + if isinstance(sub_node, InlineFragmentNode): + inline_name = ( + f"{base_fragment_name}{sub_node.type_condition.name.value}Fragment" + ) + inline_fragment_fields = [] + + inline_fragment_fields += [ + generate_typename_field( + sub_node.type_condition.name.value, registry, config + ) + ] + + for sub_sub_node in sub_node.selection_set.selections: + + if isinstance(sub_sub_node, FieldNode): + sub_sub_node_type = client_schema.get_type( + sub_node.type_condition.name.value + ) + + if sub_sub_node.name.value == "__typename": + continue + + field_type = sub_sub_node_type.fields[sub_sub_node.name.value] + inline_fragment_fields += type_field_node( + sub_sub_node, + inline_name, + field_type, + client_schema, + config, + tree, + registry, + ) + + additional_bases = get_additional_bases_for_type( + sub_node.type_condition.name.value, config, registry + ) + cls = ast.ClassDef( + inline_name, + bases=[ + ast.Name(id=base_fragment_name, ctx=ast.Load()), + ] + + additional_bases, + decorator_list=[], + keywords=[], + body=inline_fragment_fields + + generate_pydantic_config( + GraphQLTypes.FRAGMENT, + config, + registry, + sub_node.type_condition.name.value, + ), + ) + + tree.append(cls) + union_class_names.append(inline_name) + + union_class_names.append(base_fragment_name) + + if len(union_class_names) > 1: + registry.register_import("typing.Union") + slice = ast.Tuple( + elts=[ + ast.Name(id=clsname, ctx=ast.Load()) + for clsname in union_class_names + ], + ctx=ast.Load(), + ) + tree.append( + ast.Assign( + targets=[ + ast.Name( + id=registry.style_fragment_class(f.name.value), + ctx=ast.Store(), + ) + ], + value=ast.Subscript( + value=ast.Name("Union", ctx=ast.Load()), + slice=slice, + ctx=ast.Load(), + ), + ) + ) + + return tree + + elif isinstance(type, GraphQLObjectType): + additional_bases = get_additional_bases_for_type( + f.type_condition.name.value, config, registry + ) + + fields += [generate_typename_field(type.name, registry, config)] + + for field in f.selection_set.selections: + + if field.name.value == "__typename": + continue + + if isinstance(field, FragmentDefinitionNode): # pragma: no cover + + continue + + if isinstance(field, FragmentSpreadNode): + additional_bases = [ + ast.Name( + id=registry.inherit_fragment(field.name.value), + ctx=ast.Load(), + ) + ] + additional_bases # needs to be prepended (MRO) + continue + + field_definition = get_field_def(client_schema, type, field) + assert field_definition, "Couldn't find field definition" + + fields += type_field_node( + field, + name, + field_definition, + client_schema, + config, + tree, + registry, + ) + + tree.append( + ast.ClassDef( + name, + bases=additional_bases + + get_fragment_bases(config, plugin_config, registry), + decorator_list=[], + keywords=[], + body=fields + + generate_pydantic_config(GraphQLTypes.FRAGMENT, config, registry), + ) + ) + return tree + + +class FragmentsPlugin(Plugin): + """Plugin for generating fragments from + documents + + The fragments plugin will generate classes for each fragment. It loads the documents, + scans for fragments and generates the classes. + + If encountering a fragment on an interface it will generate a BASE class for that interface + and then generate a class for each type referenced in the fragment. They will all inherit + from the base class. The true type will be determined at runtime as all of the potential subtypes + will be in the same union. + + """ + + config: FragmentsPluginConfig = Field(default_factory=FragmentsPluginConfig) + + def generate_ast( + self, + client_schema: GraphQLSchema, + config: GeneratorConfig, + registry: ClassRegistry, + ) -> List[ast.AST]: + + plugin_tree = [] + + documents = parse_documents( + client_schema, self.config.fragments_glob or config.documents + ) + + definitions = documents.definitions + + fragments = [ + node for node in definitions if isinstance(node, FragmentDefinitionNode) + ] + + for fragment in fragments: + plugin_tree += generate_fragment( + fragment, client_schema, config, self.config, registry + ) + + return plugin_tree diff --git a/turms/plugins/enums.py b/turms/plugins/enums.py index 6f42706..7ba2239 100644 --- a/turms/plugins/enums.py +++ b/turms/plugins/enums.py @@ -19,7 +19,7 @@ class EnumsPluginsError(Exception): class EnumsPluginConfig(PluginConfig): - model_config = SettingsConfigDict(env_prefix = "TURMS_PLUGINS_ENUMS_") + model_config = SettingsConfigDict(env_prefix="TURMS_PLUGINS_ENUMS_") type: str = "turms.plugins.enums.EnumsPlugin" skip_underscore: bool = False skip_double_underscore: bool = True @@ -27,6 +27,7 @@ class EnumsPluginConfig(PluginConfig): prepend: str = "" append: str = "" + def generate_enums( client_schema: GraphQLSchema, config: GeneratorConfig, diff --git a/turms/plugins/fragments.py b/turms/plugins/fragments.py index cc31559..9700678 100644 --- a/turms/plugins/fragments.py +++ b/turms/plugins/fragments.py @@ -10,10 +10,12 @@ from graphql.language.ast import FragmentDefinitionNode from turms.registry import ClassRegistry from turms.utils import ( + generate_generic_typename_field, generate_pydantic_config, generate_typename_field, get_additional_bases_for_type, get_interface_bases, + non_typename_fields, parse_documents, ) from graphql import ( @@ -22,9 +24,82 @@ GraphQLInterfaceType, GraphQLObjectType, InlineFragmentNode, + SelectionSetNode, language, ) from turms.config import GraphQLTypes +from graphql import parse, print_ast +from graphql.language.ast import ( + DocumentNode, OperationDefinitionNode, FragmentDefinitionNode, FragmentSpreadNode +) +from collections import defaultdict, deque + +def find_fragment_dependencies_recursive(selection_set: SelectionSetNode, fragment_definitions, visited): + """Recursively find all fragment dependencies within a selection set.""" + dependencies = set() + if selection_set is None: + return dependencies + + for selection in selection_set.selections: + # If we encounter a fragment spread, add it to dependencies + if isinstance(selection, FragmentSpreadNode): + spread_name = selection.name.value + dependencies.add(spread_name) + # Recursively add dependencies of the spread fragment + if spread_name in fragment_definitions and spread_name not in visited: + visited.add(spread_name) # Prevent cycles in recursion + fragment = fragment_definitions[spread_name] + dependencies.update( + find_fragment_dependencies_recursive(fragment.selection_set, fragment_definitions, visited) + ) + # If it's a field with a nested selection set, dive deeper + elif isinstance(selection, FieldNode) and selection.selection_set: + dependencies.update( + find_fragment_dependencies_recursive(selection.selection_set, fragment_definitions, visited) + ) + + + return dependencies + +def build_recursive_dependency_graph(document): + """Build a dependency graph for fragments, accounting for deep nested fragment spreads.""" + fragment_definitions = { + definition.name.value: definition for definition in document.definitions + if isinstance(definition, FragmentDefinitionNode) + } + dependencies = defaultdict(set) + + # Populate the dependency graph with deeply nested fragment dependencies + for fragment_name, fragment in fragment_definitions.items(): + visited = set() # Track visited fragments to avoid cyclic dependencies + dependencies[fragment_name] = find_fragment_dependencies_recursive(fragment.selection_set, fragment_definitions, visited) + + return dependencies + + +def topological_sort(dependency_graph): + """Perform a topological sort on fragments based on recursive dependencies.""" + sorted_fragments = [] + no_dependency_fragments = deque([frag for frag, deps in dependency_graph.items() if not deps]) + resolved = set(no_dependency_fragments) + + while no_dependency_fragments: + fragment = no_dependency_fragments.popleft() + sorted_fragments.append(fragment) + + # Remove this fragment from other fragments' dependencies + for frag, deps in dependency_graph.items(): + if fragment in deps: + deps.remove(fragment) + if not deps and frag not in resolved: + no_dependency_fragments.append(frag) + resolved.add(frag) + + # Add any remaining fragments that may have been missed if they were independent + sorted_fragments.extend(frag for frag in dependency_graph if frag not in sorted_fragments) + + return sorted_fragments + from graphql.utilities.type_info import get_field_def import logging @@ -34,13 +109,12 @@ class FragmentsPluginConfig(PluginConfig): - model_config = SettingsConfigDict(env_prefix = "TURMS_PLUGINS_FRAGMENTS_") + model_config = SettingsConfigDict(env_prefix="TURMS_PLUGINS_FRAGMENTS_") type: str = "turms.plugins.fragments.FragmentsPlugin" fragment_bases: List[str] = [] fragments_glob: Optional[str] = None - def get_fragment_bases( config: GeneratorConfig, pluginConfig: FragmentsPluginConfig, @@ -65,6 +139,15 @@ def get_fragment_bases( ] +def get_implementing_types(type: GraphQLInterfaceType, client_schema: GraphQLSchema): + implementing_types = [] + for type in client_schema.get_implementing_types(type): + implementing_types.append(type) + implementing_types += get_implementing_types(type, client_schema) + return implementing_types + + + def generate_fragment( f: FragmentDefinitionNode, client_schema: GraphQLSchema, @@ -83,9 +166,17 @@ def generate_fragment( f.name.value, language.print_ast(f) ) # TODO: Check if typename is being referenced? so that we can check between the elements of the interface + registry.register_fragment_type(f.name.value, type) + + + + if isinstance(type, GraphQLInterfaceType): + + implementing_types = client_schema.get_implementations(type) + mother_class_fields = [] - base_fragment_name = f"{name}" + base_fragment_name = registry.style_fragment_class(f.name.value) additional_bases = get_additional_bases_for_type(type.name, config, registry) if type.description: @@ -93,12 +184,35 @@ def generate_fragment( ast.Expr(value=ast.Constant(value=type.description)) ) - for sub_node in f.selection_set.selections: + sub_nodes = non_typename_fields(f) + + mother_class_name = base_fragment_name + "Base" + + + implementing_class_base_classes = { + } - if isinstance(sub_node, FieldNode): - if sub_node.name.value == "__typename": - continue + inline_fragment_fields = {} + + + for sub_node in sub_nodes: + + if isinstance(sub_node, FragmentSpreadNode): + # Spread nodes are like inheritance? + try: + # We are dealing with a fragment that is an interface + implementation_map = registry.get_interface_fragment_implementations(sub_node.name.value) + for k, v in implementation_map.items(): + implementing_class_base_classes.setdefault(k, []).append(v) + + except KeyError: + x = registry.get_fragment_type(sub_node.name.value) + implementing_class_base_classes.setdefault(x, []).append(registry.inherit_fragment(sub_node.name.value)) + + + + if isinstance(sub_node, FieldNode): field_type = type.fields[sub_node.name.value] mother_class_fields += type_field_node( @@ -111,127 +225,70 @@ def generate_fragment( registry, ) + if isinstance(sub_node, InlineFragmentNode): + on_type_name = sub_node.type_condition.name.value + + inline_fragment_fields.setdefault(on_type_name, []).append( + generate_typename_field( + sub_node.type_condition.name.value, registry, config + ) + ) + + mother_class = ast.ClassDef( - base_fragment_name, - bases=get_interface_bases(config, registry) - + additional_bases, # Todo: fill with base + mother_class_name, + bases=additional_bases + get_interface_bases(config, registry) , # Todo: fill with base decorator_list=[], keywords=[], body=mother_class_fields if mother_class_fields else [ast.Pass()], ) - tree.append(mother_class) + catch_class_name = f"{base_fragment_name}Catch" - union_class_names = [] + catch_class = ast.ClassDef( + catch_class_name, + bases=[ast.Name(id=mother_class_name, ctx=ast.Load())], # Todo: fill with base + decorator_list=[], + keywords=[], + body=[generate_generic_typename_field(registry, config)] + mother_class_fields, + ) - for sub_node in f.selection_set.selections: - if isinstance(sub_node, FragmentSpreadNode): - # Spread nodes are like inheritance? - spreaded_name = f"{base_fragment_name}{sub_node.name.value}" - cls = ast.ClassDef( - spreaded_name, - bases=[ - ast.Name( - id=registry.inherit_fragment(sub_node.name.value), - ctx=ast.Load(), - ), - ast.Name(id=base_fragment_name, ctx=ast.Load()), - ], - decorator_list=[], - keywords=[], - body=[ast.Pass()], - ) + tree.append(mother_class) + tree.append(catch_class) - tree.append(cls) - union_class_names.append(spreaded_name) - if isinstance(sub_node, InlineFragmentNode): - inline_name = ( - f"{base_fragment_name}{sub_node.type_condition.name.value}Fragment" - ) - inline_fragment_fields = [] - inline_fragment_fields += [ - generate_typename_field( - sub_node.type_condition.name.value, registry, config - ) - ] + implementaionMap = {} - for sub_sub_node in sub_node.selection_set.selections: + for i in implementing_types.objects: - if isinstance(sub_sub_node, FieldNode): - sub_sub_node_type = client_schema.get_type( - sub_node.type_condition.name.value - ) + class_name = f"{base_fragment_name}{i.name}" - if sub_sub_node.name.value == "__typename": - continue - - field_type = sub_sub_node_type.fields[sub_sub_node.name.value] - inline_fragment_fields += type_field_node( - sub_sub_node, - inline_name, - field_type, - client_schema, - config, - tree, - registry, - ) - additional_bases = get_additional_bases_for_type( - sub_node.type_condition.name.value, config, registry - ) - cls = ast.ClassDef( - inline_name, - bases=[ - ast.Name(id=base_fragment_name, ctx=ast.Load()), - ] - + additional_bases, + ast_base_nodes = [ast.Name(id=x, ctx=ast.Load()) for x in implementing_class_base_classes.get(i, [])] + implementaionMap[i.name] = class_name + + inline_fields = inline_fragment_fields.get(i, []) + + implementing_class = ast.ClassDef( + class_name, + bases=ast_base_nodes + [ast.Name(id=mother_class_name, ctx=ast.Load())] + get_interface_bases(config, registry), # Todo: fill with base decorator_list=[], keywords=[], - body=inline_fragment_fields - + generate_pydantic_config( - GraphQLTypes.FRAGMENT, - config, - registry, - sub_node.type_condition.name.value, - ), - ) + body=[generate_typename_field(i.name, registry, config)] + inline_fields, + ) - tree.append(cls) - union_class_names.append(inline_name) + tree.append(implementing_class) - union_class_names.append(base_fragment_name) - if len(union_class_names) > 1: - registry.register_import("typing.Union") - slice = ast.Tuple( - elts=[ - ast.Name(id=clsname, ctx=ast.Load()) - for clsname in union_class_names - ], - ctx=ast.Load(), - ) - tree.append( - ast.Assign( - targets=[ - ast.Name( - id=registry.style_fragment_class(f.name.value), - ctx=ast.Store(), - ) - ], - value=ast.Subscript( - value=ast.Name("Union", ctx=ast.Load()), - slice=slice, - ctx=ast.Load(), - ), - ) - ) + registry.register_interface_fragment_implementations(f.name.value, implementaionMap) + return tree + elif isinstance(type, GraphQLObjectType): additional_bases = get_additional_bases_for_type( f.type_condition.name.value, config, registry @@ -249,12 +306,24 @@ def generate_fragment( continue if isinstance(field, FragmentSpreadNode): - additional_bases = [ - ast.Name( - id=registry.inherit_fragment(field.name.value), - ctx=ast.Load(), - ) - ] + additional_bases # needs to be prepended (MRO) + try: + implementationMap = registry.get_interface_fragment_implementations(field.name.value) + if type.name in implementationMap: + additional_bases = [ + ast.Name( + id=implementationMap[type.name], + ctx=ast.Load(), + ) + ] + additional_bases + else: + raise Exception(f"Could not find implementation for {type.name} in {implementationMap}") + except KeyError: + additional_bases = [ + ast.Name( + id=registry.inherit_fragment(field.name.value), + ctx=ast.Load(), + ) + ] + additional_bases # needs to be prepended (MRO) continue field_definition = get_field_def(client_schema, type, field) @@ -277,11 +346,22 @@ def generate_fragment( + get_fragment_bases(config, plugin_config, registry), decorator_list=[], keywords=[], - body=fields + generate_pydantic_config(GraphQLTypes.FRAGMENT, config, registry), + body=fields + + generate_pydantic_config(GraphQLTypes.FRAGMENT, config, registry), ) ) return tree +def reorder_definitions(document, sorted_fragments): + """Reorder document definitions to place fragments in dependency order.""" + fragment_definitions = {defn.name.value: defn for defn in document.definitions if isinstance(defn, FragmentDefinitionNode)} + + # Order fragments according to the topologically sorted order + ordered_fragments = [fragment_definitions[name] for name in sorted_fragments if name in fragment_definitions] + + # Combine operations and ordered fragments + return ordered_fragments + class FragmentsPlugin(Plugin): """Plugin for generating fragments from @@ -312,13 +392,19 @@ def generate_ast( client_schema, self.config.fragments_glob or config.documents ) - definitions = documents.definitions + # Find dependencies and sort fragments topologically + fragment_dependencies = build_recursive_dependency_graph(documents) + + sorted_fragments = topological_sort(fragment_dependencies) + - fragments = [ - node for node in definitions if isinstance(node, FragmentDefinitionNode) - ] - for fragment in fragments: + ordered_fragments = reorder_definitions(documents, sorted_fragments) + + + + + for fragment in ordered_fragments: plugin_tree += generate_fragment( fragment, client_schema, config, self.config, registry ) diff --git a/turms/plugins/funcs.py b/turms/plugins/funcs.py index 73afd9e..8a4bda9 100644 --- a/turms/plugins/funcs.py +++ b/turms/plugins/funcs.py @@ -26,6 +26,7 @@ from turms.registry import ClassRegistry from turms.utils import ( inspect_operation_for_documentation, + non_typename_fields, parse_documents, parse_value_node, recurse_outputtype_annotation, @@ -323,17 +324,23 @@ def get_return_type_annotation( o_name = get_operation_class_name(o, registry) root = get_operation_root_type(client_schema, o) + if collapse is True: + + collapsable_field = o.selection_set.selections[0] + + sub_nodes = non_typename_fields(collapsable_field) field_definition = get_field_def(client_schema, root, collapsable_field) - if collapsable_field.selection_set is None: # pragma: no cover + if len(sub_nodes) == 0: # pragma: no cover return recurse_outputtype_annotation(field_definition.type, registry) + if ( - len(collapsable_field.selection_set.selections) == 1 + len(sub_nodes) == 1 ): # Dealing with one Element - collapsable_fragment_field = collapsable_field.selection_set.selections[0] + collapsable_fragment_field = sub_nodes[0] if isinstance( collapsable_fragment_field, FragmentSpreadNode ): # Dealing with a on element fragment diff --git a/turms/plugins/inputs.py b/turms/plugins/inputs.py index 0bfbfbf..df588cd 100644 --- a/turms/plugins/inputs.py +++ b/turms/plugins/inputs.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from graphql import ( GraphQLInputObjectType, GraphQLInputType, @@ -8,7 +9,7 @@ from pydantic_settings import SettingsConfigDict from turms.plugins.base import Plugin, PluginConfig import ast -from typing import List +from typing import Dict, List, Optional from turms.config import GeneratorConfig from graphql.utilities.build_client_schema import GraphQLSchema from turms.plugins.base import Plugin @@ -27,9 +28,12 @@ class InputsPluginConfig(PluginConfig): - model_config = SettingsConfigDict(extra="forbid", env_prefix="TURMS_PLUGINS_INPUTS_") + model_config = SettingsConfigDict( + extra="forbid", env_prefix="TURMS_PLUGINS_INPUTS_" + ) type: str = "turms.plugins.inputs.InputsPlugin" inputtype_bases: List[str] = ["pydantic.BaseModel"] + allow_population_by_field_name: bool = True skip_underscore: bool = True skip_unreferenced: bool = True @@ -92,7 +96,7 @@ def generate_input_annotation( def list_builder(x): return ast.Subscript( value=ast.Name("Tuple", ctx=ast.Load()), - slice=ast.Tuple(elts=[x, ast.Constant(value=...)], ctx=ast.Load()), + slice=ast.Tuple(elts=[x, ast.Constant(value=...)], ctx=ast.Load()), ctx=ast.Load(), ) @@ -130,6 +134,138 @@ def list_builder(x): raise NotImplementedError(f"Unknown input type {type}") + +@dataclass +class Discriminator: + discriminator: str + value: str + + +def generate_input_type( + name: str, + union_type_descriminators: Dict[str, str], + type: GraphQLInputType, + config: GeneratorConfig, + plugin_config: InputsPluginConfig, + registry: ClassRegistry, + key: str, + discriminator: Optional[Discriminator] = None, +): + + additional_bases = get_additional_bases_for_type(type.name, config, registry) + + fields = ( + [ast.Expr(value=ast.Constant(value=type.description))] + if type.description + else [] + ) + + if discriminator: + fields.append( + ast.AnnAssign( + target=ast.Name(discriminator.discriminator, ctx=ast.Store()), + annotation=ast.Subscript(value=ast.Name("Literal", ctx=ast.Load()), slice=ast.Constant(value=discriminator.value), ctx=ast.Load()), + value=ast.Call( + func=ast.Name(id="Field", ctx=ast.Load()), + args=[], + keywords=[ast.keyword(arg="default", value=ast.Constant(value=discriminator.value))] + ), + simple=1, + ) + ) + + for value_key, value in type.fields.items(): + field_name = registry.generate_node_name(value_key) + + + + + if field_name != value_key: + registry.register_import("pydantic.Field") + + keywords = [ + ast.keyword(arg="alias", value=ast.Constant(value=value_key)) + ] + if not isinstance(value.type, GraphQLNonNull): + keywords.append( + ast.keyword(arg="default", value=ast.Constant(None)) + ) + + + assign = ast.AnnAssign( + target=ast.Name(field_name, ctx=ast.Store()), + annotation=generate_input_annotation( + value.type, + name, + config, + plugin_config, + registry, + is_optional=True, + ), + value=ast.Call( + func=ast.Name(id="Field", ctx=ast.Load()), + args=[], + keywords=keywords, + ), + simple=1, + ) + + else: + assign = ast.AnnAssign( + target=ast.Name(value_key, ctx=ast.Store()), + annotation=generate_input_annotation( + value.type, + name, + config, + plugin_config, + registry, + is_optional=True, + ), + simple=1, + value=( + ast.Constant(None) + if not isinstance(value.type, GraphQLNonNull) + else None + ), + ) + + potential_comment = ( + value.description + if not value.deprecation_reason + else f"DEPRECATED: {value.description}" + ) + + if potential_comment: + fields += [ + assign, + ast.Expr(value=ast.Constant(value=potential_comment)), + ] + + else: + fields += [assign] + + + + return ast.ClassDef( + name, + bases=additional_bases + + [ + ast.Name(id=base.split(".")[-1], ctx=ast.Load()) + for base in plugin_config.inputtype_bases + ], + decorator_list=[], + keywords=[], + body=fields + + generate_pydantic_config( + GraphQLTypes.INPUT, config, registry, typename=key + ), + ) + + + + + + def generate_inputs( client_schema: GraphQLSchema, config: GeneratorConfig, @@ -154,6 +290,48 @@ def generate_inputs( for base in plugin_config.inputtype_bases: registry.register_import(base) + + + union_input_types = {} + union_type_discriminators = {} + + for key, type in inputobjects_type.items(): + directives = type.ast_node.directives if type.ast_node else [] + for directive in directives: + + directive_name = directive.name.value + if directive_name == "unionElementOf": + + union_type = None + discriminator = None + key = None + for arg in directive.arguments: + if arg.name.value == "union": + union_type = arg.value.value + if arg.name.value == "discriminator": + discriminator = arg.value.value + if arg.name.value == "key": + key = arg.value.value + + + if union_type in ref_registry.inputs: + if union_type not in union_input_types: + union_input_types[union_type] = [] + if union_type not in union_type_discriminators: + union_type_discriminators[union_type] = discriminator + + assert union_type_discriminators[union_type] == discriminator, f"Discriminator mismatch for {union_type} expected {union_type_discriminators[union_type]} got {discriminator}" + + name = registry.generate_inputtype(type.name) + union_input_types[union_type].append(name) + tree.append(generate_input_type(name, union_type_discriminators, type, config, plugin_config, registry, type.name, Discriminator(discriminator=discriminator, value=key))) + + + + + + + for key, type in inputobjects_type.items(): if ref_registry and key not in ref_registry.inputs: continue @@ -161,6 +339,56 @@ def generate_inputs( if plugin_config.skip_underscore and key.startswith("_"): # pragma: no cover continue + if type.name in union_input_types: + registry.register_import("typing.Union") + registry.register_import("typing.Annotated") + registry.register_import("pydantic.Field") + union_slice = ast.Tuple( + elts=[ + ast.Name(id=clsname, ctx=ast.Load()) + for clsname in union_input_types[type.name] + ], + ctx=ast.Load(), + ) + + slice = ast.Tuple( + elts=[ + ast.Subscript( + value=ast.Name("Union", ctx=ast.Load()), + slice=union_slice, + ctx=ast.Load(), + ), + ast.Call( + func=ast.Name(id="Field", ctx=ast.Load()), + args=[], + keywords=[ast.keyword(arg="discriminator", value=ast.Constant(union_type_discriminators[type.name]))], + ) + ], + ctx=ast.Load(), + ) + + + + + tree.append( + ast.Assign( + targets=[ + ast.Name( + id=registry.generate_inputtype(type.name), + ctx=ast.Store(), + ) + ], + value=ast.Subscript( + value=ast.Name("Annotated", ctx=ast.Load()), + slice=slice, + ctx=ast.Load(), + ), + ) + ) + + continue + + additional_bases = get_additional_bases_for_type(type.name, config, registry) name = registry.generate_inputtype(key) fields = ( @@ -234,6 +462,8 @@ def generate_inputs( else: fields += [assign] + + tree.append( ast.ClassDef( name, @@ -245,10 +475,17 @@ def generate_inputs( decorator_list=[], keywords=[], body=fields - + generate_pydantic_config(GraphQLTypes.INPUT, config, registry, typename=key), + + generate_pydantic_config( + GraphQLTypes.INPUT, config, registry, typename=key + ), ) ) + + + + + return tree diff --git a/turms/plugins/objects.py b/turms/plugins/objects.py index 6a8329b..ddd95fb 100644 --- a/turms/plugins/objects.py +++ b/turms/plugins/objects.py @@ -30,14 +30,13 @@ class ObjectsPluginConfig(PluginConfig): - model_config = SettingsConfigDict(env_prefix = "TURMS_PLUGINS_OBJECTS_") + model_config = SettingsConfigDict(env_prefix="TURMS_PLUGINS_OBJECTS_") type: str = "turms.plugins.objects.ObjectsPlugin" types_bases: List[str] = ["pydantic.BaseModel"] skip_underscore: bool = False skip_double_underscore: bool = True - def generate_object_field_annotation( graphql_type: GraphQLType, parent: str, @@ -161,7 +160,7 @@ def generate_object_field_annotation( def list_builder(x): return ast.Subscript( value=ast.Name("Tuple", ctx=ast.Load()), - slice=ast.Tuple(elts=[x, ast.Constant(value=...)], ctx=ast.Load()), + slice=ast.Tuple(elts=[x, ast.Constant(value=...)], ctx=ast.Load()), ctx=ast.Load(), ) @@ -287,7 +286,9 @@ def generate_types( keywords = [] if not isinstance(value.type, GraphQLNonNull): - keywords.append(ast.keyword(arg="default", value=ast.Constant(value=None))) + keywords.append( + ast.keyword(arg="default", value=ast.Constant(value=None)) + ) if field_name != value_key: registry.register_import("pydantic.Field") @@ -304,11 +305,11 @@ def generate_types( value=ast.Call( func=ast.Name(id="Field", ctx=ast.Load()), args=[], - keywords=keywords + [ + keywords=keywords + + [ ast.keyword( arg="alias", value=ast.Constant(value=value_key) ) - ], ), simple=1, @@ -326,11 +327,15 @@ def generate_types( registry, is_optional=True, ), - value=ast.Call( - func=ast.Name(id="Field", ctx=ast.Load()), - args=[], - keywords=keywords, - ) if keywords else None, + value=( + ast.Call( + func=ast.Name(id="Field", ctx=ast.Load()), + args=[], + keywords=keywords, + ) + if keywords + else None + ), simple=1, ) @@ -359,7 +364,8 @@ def generate_types( ], decorator_list=[], keywords=[], - body=fields + generate_pydantic_config(GraphQLTypes.OBJECT, config, registry, key), + body=fields + + generate_pydantic_config(GraphQLTypes.OBJECT, config, registry, key), ) ) diff --git a/turms/plugins/operations.py b/turms/plugins/operations.py index 2e06529..bddb4fb 100644 --- a/turms/plugins/operations.py +++ b/turms/plugins/operations.py @@ -18,7 +18,7 @@ from graphql.utilities.type_info import get_field_def import re -from graphql import NonNullTypeNode, language +from graphql import NonNullTypeNode, VariableDefinitionNode, language from turms.registry import ClassRegistry from turms.utils import ( generate_pydantic_config, @@ -36,7 +36,7 @@ class OperationsPluginConfig(PluginConfig): - model_config = SettingsConfigDict(env_prefix = "TURMS_PLUGINS_OPERATIONS_") + model_config = SettingsConfigDict(env_prefix="TURMS_PLUGINS_OPERATIONS_") type: str = "turms.plugins.operations.OperationsPlugin" query_bases: List[str] = [] arguments_bases: List[str] = [] @@ -48,7 +48,6 @@ class OperationsPluginConfig(PluginConfig): arguments_allow_population_by_field_name: bool = False - def get_query_bases( config: GeneratorConfig, plugin_config: OperationsPluginConfig, @@ -101,30 +100,61 @@ def generate_arguments_config( plugin_config: OperationsPluginConfig, registry: ClassRegistry, ): - config_fields = [] - if plugin_config.arguments_allow_population_by_field_name: - config_fields.append( - ast.Assign( - targets=[ - ast.Name(id="allow_population_by_field_name", ctx=ast.Store()) - ], - value=ast.Constant(value=True), + if config.pydantic_version == "1": + config_fields = [] + + if plugin_config.arguments_allow_population_by_field_name: + config_fields.append( + ast.Assign( + targets=[ + ast.Name(id="allow_population_by_field_name", ctx=ast.Store()) + ], + value=ast.Constant(value=True), + ) ) - ) - if len(config_fields) > 0: - return [ - ast.ClassDef( - name="Config", - bases=[], - keywords=[], - body=config_fields, - decorator_list=[], - ) - ] + if len(config_fields) > 0: + return [ + ast.ClassDef( + name="Config", + bases=[], + keywords=[], + body=config_fields, + decorator_list=[], + ) + ] + else: + return [] + else: - return [] + + config_keywords = [] + + if plugin_config.arguments_allow_population_by_field_name is not None: + config_keywords.append( + ast.keyword( + arg="populate_by_name", + value=ast.Constant( + value=config.options.allow_population_by_field_name + ), + ) + ) + + if len(config_keywords) > 0: + registry.register_import("pydantic.ConfigDict") + return [ + ast.Assign( + targets=[ast.Name(id="model_config", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="ConfigDict", ctx=ast.Load()), + args=[], + keywords=config_keywords, + ), + ) + ] + else: + return [] def get_arguments_bases( @@ -173,6 +203,67 @@ def get_subscription_bases( ] +def generate_expanded_types( + v: VariableDefinitionNode, + client_schema: GraphQLSchema, + config: GeneratorConfig, + plugin_config: OperationsPluginConfig, + registry: ClassRegistry, +): + + if isinstance(v.type, NonNullTypeNode): + type_node = v.type.type + else: + type_node = v.type + + type_name = type_node.name.value + type_definition = client_schema.get_type(type_name) + + if type_definition is None: + return [] + + if type_definition.ast_node is None: + return [] + + if type_definition.ast_node.kind != "InputObjectTypeDefinition": + return [] + + fields = type_definition.ast_node.fields + + fields_body = [] + + additional_keywords = [] + + for field in fields: + field_name = field.name.value + field_type = field.type + + is_optional = not isinstance(field_type, NonNullTypeNode) + annotation = recurse_type_annotation(field_type, registry) + + if is_optional: + assign = ast.AnnAssign( + target=ast.Name(field_name, ctx=ast.Store()), + annotation=annotation, + value=ast.Call( + func=ast.Name(id="Field", ctx=ast.Load()), + args=[], + keywords=[ + ast.keyword(arg="alias", value=ast.Constant(value=field_name)) + ], + ), + simple=1, + ) + else: + assign = ast.AnnAssign( + target=ast.Name(field_name, ctx=ast.Store()), + annotation=annotation, + simple=1, + ) + + fields_body += [assign] + + def generate_operation( o: OperationDefinitionNode, client_schema: GraphQLSchema, @@ -350,7 +441,8 @@ def generate_operation( bases=extra_bases, decorator_list=[], keywords=[], - body=class_body_fields + generate_pydantic_config(o.operation, config, registry), + body=class_body_fields + + generate_pydantic_config(o.operation, config, registry), ) ) diff --git a/turms/plugins/strawberry.py b/turms/plugins/strawberry.py index 3c04eba..76ed7cc 100644 --- a/turms/plugins/strawberry.py +++ b/turms/plugins/strawberry.py @@ -134,7 +134,9 @@ def default_generate_directives( decorator_list=[decorator], keywords=[], body=(fields or [ast.Pass()]) - + generate_pydantic_config(GraphQLTypes.DIRECTIVE, config, registry, directive.name), + + generate_pydantic_config( + GraphQLTypes.DIRECTIVE, config, registry, directive.name + ), ) ) @@ -224,7 +226,7 @@ def default_generate_enums( class StrawberryPluginConfig(PluginConfig): - model_config = SettingsConfigDict(env_prefix = "TURMS_PLUGINS_STRAWBERRY_") + model_config = SettingsConfigDict(env_prefix="TURMS_PLUGINS_STRAWBERRY_") type: str = "turms.plugins.strawberry.Strawberry" generate_directives: bool = True generate_scalars: bool = True @@ -243,7 +245,6 @@ class StrawberryPluginConfig(PluginConfig): generate_enums_func: StrawberryGenerateFunc = default_generate_enums - def generate_object_field_annotation( graphql_type: GraphQLType, parent: str, @@ -674,7 +675,9 @@ def generate_inputs( decorator_list=[decorator], keywords=[], body=fields - + generate_pydantic_config(GraphQLTypes.INPUT, config, registry, typename=key), + + generate_pydantic_config( + GraphQLTypes.INPUT, config, registry, typename=key + ), ) ) @@ -911,7 +914,8 @@ def generate_types( bases=additional_bases, decorator_list=[decorator], keywords=[], - body=fields + generate_pydantic_config(GraphQLTypes.OBJECT, config, registry, key), + body=fields + + generate_pydantic_config(GraphQLTypes.OBJECT, config, registry, key), ) ) diff --git a/turms/processors/base.py b/turms/processors/base.py index ae209e6..1ab6f47 100644 --- a/turms/processors/base.py +++ b/turms/processors/base.py @@ -10,7 +10,6 @@ class ProcessorConfig(BaseSettings): type: str - class Processor(BaseModel): """Base class for all processors diff --git a/turms/recurse.py b/turms/recurse.py index 902b3eb..b33534e 100644 --- a/turms/recurse.py +++ b/turms/recurse.py @@ -7,10 +7,12 @@ ) from turms.registry import ClassRegistry from turms.utils import ( + generate_generic_typename_field, generate_pydantic_config, generate_typename_field, get_additional_bases_for_type, get_interface_bases, + non_typename_fields, target_from_node, ) import ast @@ -73,10 +75,11 @@ class X(BaseModel): union_class_names = [] - for sub_node in node.selection_set.selections: + sub_nodes = non_typename_fields(node) - if isinstance(sub_node, FragmentSpreadNode): + for sub_node in sub_nodes: + if isinstance(sub_node, FragmentSpreadNode): fragment_name = registry.inherit_fragment(sub_node.name.value) union_class_names.append(fragment_name) @@ -172,22 +175,7 @@ class X(BaseModel): mother_class_fields = [] target = target_from_node(node) - # SINGLE SPREAD, AUTO COLLAPSING - if len(node.selection_set.selections) == 1: - # If there is only one field and its a fragment, we can just use the fragment - - subnode = node.selection_set.selections[0] - if isinstance(subnode, FragmentSpreadNode): - if is_optional: - registry.register_import("typing.Optional") - return ast.Subscript( - value=ast.Name("Optional", ctx=ast.Load()), - slice=registry.reference_fragment(subnode.name.value, parent), - ctx=ast.Load(), - ) - - else: - return registry.reference_fragment(subnode.name.value, parent) + sub_nodes = non_typename_fields(node) base_name = f"{parent}{target.capitalize()}" @@ -196,11 +184,34 @@ class X(BaseModel): ast.Expr(value=ast.Constant(value=type.description)) ) - for sub_node in node.selection_set.selections: + implementing_types = client_schema.get_implementations(type) + + + implementing_class_base_classes = { + + } + + inline_fragment_fields = {} + + + + + for sub_node in sub_nodes: + + if isinstance(sub_node, FragmentSpreadNode): + # Spread nodes are like inheritance? + try: + # We are dealing with a fragment that is an interface + implementation_map = registry.get_interface_fragment_implementations(sub_node.name.value) + for k, v in implementation_map.items(): + implementing_class_base_classes.setdefault(k, []).append(v) + + except KeyError: + x = registry.get_fragment_type(sub_node.name.value) + implementing_class_base_classes.setdefault(x.name, []).append(registry.inherit_fragment(sub_node.name.value)) + if isinstance(sub_node, FieldNode): - if sub_node.name.value == "__typename": - continue field_type = type.fields[sub_node.name.value] mother_class_fields += type_field_node( @@ -213,153 +224,118 @@ class X(BaseModel): registry, ) + if isinstance(sub_node, InlineFragmentNode): + + + on_type_name = sub_node.type_condition.name.value + + inline_fragment_fields.setdefault(on_type_name, []).append( + generate_typename_field( + sub_node.type_condition.name.value, registry, config + ) + ) + + + + + # We first genrate the mother class that will provide common fields of this fragment. This will never be reference # though mother_class_name = f"{base_name}Base" additional_bases = get_additional_bases_for_type(type.name, config, registry) + body = mother_class_fields if mother_class_fields else [ast.Pass()] + + mother_class = ast.ClassDef( mother_class_name, - bases=get_interface_bases(config, registry) + additional_bases, + bases=additional_bases + get_interface_bases(config, registry), decorator_list=[], keywords=[], - body=body + generate_pydantic_config(GraphQLTypes.FRAGMENT, config, registry), + body=body + + generate_pydantic_config(GraphQLTypes.FRAGMENT, config, registry), ) subtree.append(mother_class) + implementaionMap = {} union_class_names = [] - for sub_node in node.selection_set.selections: + for i in implementing_types.objects: - if isinstance(sub_node, FragmentSpreadNode): + class_name = f"{mother_class_name}{i.name}" - spreaded_fragment_classname = f"{base_name}{sub_node.name.value}" + ast_base_nodes = [ast.Name(id=x, ctx=ast.Load()) for x in implementing_class_base_classes.get(i.name, [])] + implementaionMap[i.name] = class_name - cls = ast.ClassDef( - spreaded_fragment_classname, - bases=[ - ast.Name(id=mother_class_name, ctx=ast.Load()), - ast.Name( - id=registry.inherit_fragment(sub_node.name.value), - ctx=ast.Load(), - ), - ], + inline_fields = inline_fragment_fields.get(i, []) + + + implementing_class = ast.ClassDef( + class_name, + bases=ast_base_nodes + [ast.Name(id=mother_class_name, ctx=ast.Load())] + get_interface_bases(config, registry), # Todo: fill with base decorator_list=[], keywords=[], - body=[ast.Pass()] - + generate_pydantic_config(GraphQLTypes.FRAGMENT, config, registry), - ) - - subtree.append(cls) - union_class_names.append(spreaded_fragment_classname) - - if isinstance(sub_node, InlineFragmentNode): - inline_fragment_name = ( - f"{base_name}{sub_node.type_condition.name.value}InlineFragment" - ) - inline_fragment_fields = [] + body=[generate_typename_field(i.name, registry, config)] + inline_fields, + ) - inline_fragment_fields += [ - generate_typename_field( - sub_node.type_condition.name.value, registry, config - ) - ] + subtree.append(implementing_class) + union_class_names.append(class_name) - additional_bases = get_additional_bases_for_type( - sub_node.type_condition.name.value, config, registry - ) - for sub_sub_node in sub_node.selection_set.selections: - if isinstance(sub_sub_node, FieldNode): - sub_sub_node_type = client_schema.get_type( - sub_node.type_condition.name.value - ) + registry.register_import("typing.Annotated") + registry.register_import("typing.Union") + union_slice = ast.Tuple( + elts=[ + ast.Name(id=clsname, ctx=ast.Load()) + for clsname in union_class_names + ], + ctx=ast.Load(), + ) - if sub_sub_node.name.value == "__typename": - continue + slice = ast.Tuple( + elts=[ + ast.Subscript( + value=ast.Name("Union", ctx=ast.Load()), + slice=union_slice, + ctx=ast.Load(), + ), + ast.Call( + func=ast.Name(id="Field", ctx=ast.Load()), + args=[], + keywords=[ast.keyword(arg="discriminator", value=ast.Constant("typename"))], + ) + ], + ctx=ast.Load(), + ) - field_type = sub_sub_node_type.fields[sub_sub_node.name.value] - inline_fragment_fields += type_field_node( - sub_sub_node, - inline_fragment_name, - field_type, - client_schema, - config, - subtree, - registry, - ) + # Resort to base class if we have no sub-fragments + annotated_slice = ast.Subscript( + value=ast.Name("Annotated", ctx=ast.Load()), + slice=slice, + ctx=ast.Load(), + ) + - elif isinstance(sub_sub_node, FragmentSpreadNode): - additional_bases.append( - registry.reference_fragment(sub_sub_node.name.value, parent) - ) - cls = ast.ClassDef( - inline_fragment_name, - bases=additional_bases - + [ - ast.Name(id=mother_class_name, ctx=ast.Load()), - ], - decorator_list=[], - keywords=[], - body=inline_fragment_fields - + generate_pydantic_config(GraphQLTypes.FRAGMENT, config, registry), - ) - subtree.append(cls) - union_class_names.append(inline_fragment_name) - if not config.always_resolve_interfaces: - union_class_names.append(mother_class_name) - assert ( - len(union_class_names) != 0 - ), f"You have set 'always_resolve_interfaces' to True but you have no sub-fragments in your query of {base_name}" + if is_optional: + registry.register_import("typing.Optional") - if len(union_class_names) > 1: - registry.register_import("typing.Union") - union_slice = ast.Tuple( - elts=[ - ast.Name(id=clsname, ctx=ast.Load()) - for clsname in union_class_names - ], + return ast.Subscript( + value=ast.Name("Optional", ctx=ast.Load()), + slice=annotated_slice, ctx=ast.Load(), ) - - if is_optional: - registry.register_import("typing.Optional") - - return ast.Subscript( - value=ast.Name("Optional", ctx=ast.Load()), - slice=ast.Subscript( - value=ast.Name("Union", ctx=ast.Load()), - slice=union_slice, - ), - ctx=ast.Load(), - ) - else: - registry.register_import("typing.Union") - return ast.Subscript( - value=ast.Name("Union", ctx=ast.Load()), - slice=union_slice, - ctx=ast.Load(), - ) else: - if is_optional: - registry.register_import("typing.Optional") - - return ast.Subscript( - value=ast.Name("Optional", ctx=ast.Load()), - slice=ast.Name(id=union_class_names[0], ctx=ast.Load()), - ctx=ast.Load(), - ) - return ast.Name(id=union_class_names[0], ctx=ast.Load()) + registry.register_import("typing.Union") + return annotated_slice if isinstance(type, GraphQLObjectType): pick_fields = [] - additional_bases = get_additional_bases_for_type(type.name, config, registry) target = target_from_node(node) object_class_name = f"{parent}{target.capitalize()}" @@ -369,10 +345,14 @@ class X(BaseModel): pick_fields += [generate_typename_field(type.name, registry, config)] + sub_nodes = non_typename_fields(node) + # Single Item collapse - if len(node.selection_set.selections) == 1: - sub_node = node.selection_set.selections[0] + if len(sub_nodes) == 1: + sub_node = sub_nodes[0] + if isinstance(sub_node, FragmentSpreadNode): + if is_optional: registry.register_import("typing.Optional") return ast.Subscript( @@ -388,15 +368,24 @@ class X(BaseModel): sub_node.name.value, parent ) # needs to be parent not object as reference will be to parent - for sub_node in node.selection_set.selections: + + + additional_bases = [] + + for sub_node in sub_nodes: if isinstance(sub_node, FragmentSpreadNode): - additional_bases.append( - ast.Name( - id=registry.inherit_fragment(sub_node.name.value), - ctx=ast.Load(), + + if registry.is_interface_fragment(sub_node.name.value): + raise Exception("Interface Fragments with additional subfields are not yet implemented") + + else: + additional_bases.append( + ast.Name( + id=registry.inherit_fragment(sub_node.name.value), + ctx=ast.Load(), + ) ) - ) if isinstance(sub_node, FieldNode): if sub_node.name.value == "__typename": @@ -414,6 +403,11 @@ class X(BaseModel): if isinstance(sub_node, InlineFragmentNode): raise NotImplementedError("Inline Fragments are not yet implemented") + + if not additional_bases: + # We need to add the base class if we have no fragments + additional_bases = get_additional_bases_for_type(type.name, config, registry) + body = pick_fields if pick_fields else [ast.Pass()] @@ -590,7 +584,8 @@ def type_field_node( value=ast.Call( func=ast.Name(id="Field", ctx=ast.Load()), args=[], - keywords=keywords + [ast.keyword(arg="alias", value=ast.Constant(value=target))], + keywords=keywords + + [ast.keyword(arg="alias", value=ast.Constant(value=target))], ), simple=1, ) @@ -609,11 +604,15 @@ def type_field_node( registry, is_optional=is_optional, ), - value=ast.Call( - func=ast.Name(id="Field", ctx=ast.Load()), - args=[], - keywords=keywords, - ) if keywords else None, + value=( + ast.Call( + func=ast.Name(id="Field", ctx=ast.Load()), + args=[], + keywords=keywords, + ) + if keywords + else None + ), simple=1, ) diff --git a/turms/registry.py b/turms/registry.py index 41ae4aa..02d99c1 100644 --- a/turms/registry.py +++ b/turms/registry.py @@ -101,9 +101,13 @@ def __init__( self.subscription_class_map = {} self.mutation_class_map = {} + self.registered_interfaces_fragments = {} + self.forward_references = set() + self.fragment_type_map = {} self.interfacefragments_class_map = {} + self.interfacefragments_impl_map = {} self.log = log def style_inputtype_class(self, typename: str): @@ -264,10 +268,27 @@ def generate_fragment(self, fragmentname: str, is_interface=False): fragmentname not in self.fragment_class_map ), f"Fragment {fragmentname} was already registered, cannot register annew" classname = self.style_fragment_class(fragmentname) - real_classname = classname if not is_interface else classname + "Base" + real_classname = classname if not is_interface else classname self.fragment_class_map[fragmentname] = real_classname return real_classname + def register_fragment_type(self, fragmentname: str, typename: str): + self.fragment_type_map[fragmentname] = typename + + + def register_interface_fragment_implementations(self, fragmentname: str, implementationMap: Dict[str, str]): + self.interfacefragments_impl_map[fragmentname] = implementationMap + + + def get_interface_fragment_implementations(self, fragmentname: str): + return self.interfacefragments_impl_map[fragmentname] + + + def get_fragment_type(self, fragmentname: str): + return self.fragment_type_map[fragmentname] + + + def reference_fragment( self, typename: str, parent: str, allow_forward=True ) -> ast.AST: @@ -279,6 +300,16 @@ def reference_fragment( "Fragment", allow_forward, ) + + def is_interface_fragment(self, typename: str): + return typename in self.registered_interfaces_fragments + + + def reference_interface_fragment(self, typename: str, parent: str, allow_forward=True) -> ast.AST: + return self.registered_interfaces_fragments[typename] + + def register_interface_fragment(self, typename: str, ast: ast.AST): + self.registered_interfaces_fragments[typename] = ast def inherit_fragment(self, typename: str, allow_forward=True) -> ast.AST: if typename not in self.fragment_class_map: @@ -439,7 +470,11 @@ def generate_forward_refs(self): id=reference, ctx=ast.Load(), ), - attr="model_rebuild" if self.config.pydantic_version == "v2" else "update_forward_refs", + attr=( + "model_rebuild" + if self.config.pydantic_version == "v2" + else "update_forward_refs" + ), ctx=ast.Load(), ), keywords=[], diff --git a/turms/run.py b/turms/run.py index bb98662..843f5fe 100644 --- a/turms/run.py +++ b/turms/run.py @@ -349,19 +349,25 @@ def log(x, **kwargs): processors = [] for parser_config in gen_config.parsers: - styler = instantiate(parser_config.type, config=parser_config.model_dump(), log=log) + styler = instantiate( + parser_config.type, config=parser_config.model_dump(), log=log + ) if verbose: get_console().print(f"Using Parser {styler}") parsers.append(styler) for plugins_config in gen_config.plugins: - styler = instantiate(plugins_config.type, config=plugins_config.model_dump(), log=log) + styler = instantiate( + plugins_config.type, config=plugins_config.model_dump(), log=log + ) if verbose: get_console().print(f"Using Plugin {styler}") plugins.append(styler) for styler_config in gen_config.stylers: - styler = instantiate(styler_config.type, config=styler_config.model_dump(), log=log) + styler = instantiate( + styler_config.type, config=styler_config.model_dump(), log=log + ) if verbose: get_console().print(f"Using Styler {styler}") stylers.append(styler) diff --git a/turms/stylers/base.py b/turms/stylers/base.py index 9ffeecd..08bdd96 100644 --- a/turms/stylers/base.py +++ b/turms/stylers/base.py @@ -9,7 +9,6 @@ class StylerConfig(BaseSettings): type: str - class Styler(BaseModel): """Base class for all stylers diff --git a/turms/utils.py b/turms/utils.py index ad99227..9974640 100644 --- a/turms/utils.py +++ b/turms/utils.py @@ -4,11 +4,12 @@ from turms.config import GeneratorConfig from turms.errors import GenerationError from graphql.utilities.build_client_schema import GraphQLSchema -from graphql.language.ast import DocumentNode, FieldNode +from graphql.language.ast import DocumentNode, FieldNode, NameNode from graphql.error.graphql_error import GraphQLError from graphql import ( BooleanValueNode, FloatValueNode, + FragmentDefinitionNode, GraphQLEnumType, GraphQLList, GraphQLNonNull, @@ -21,9 +22,11 @@ NonNullTypeNode, NullValueNode, OperationDefinitionNode, + SelectionSetNode, StringValueNode, ValueNode, parse, + print_ast, validate, GraphQLInterfaceType, ) @@ -65,6 +68,13 @@ def target_from_node(node: FieldNode) -> str: ) +def non_typename_fields(node: FieldNode) -> List[FieldNode]: + """Returns all fields in a FieldNode that are not __typename""" + if not node.selection_set: + return [] + return [field for field in node.selection_set.selections if not (isinstance(field, FieldNode) and field.name.value == "__typename")] + + def inspect_operation_for_documentation(operation: OperationDefinitionNode): """Checks for operation level documentatoin""" @@ -104,14 +114,36 @@ def generate_typename_field( return ast.AnnAssign( target=ast.Name(id="typename", ctx=ast.Store()), annotation=ast.Subscript( - value=ast.Name(id="Optional", ctx=ast.Load()), - slice=ast.Subscript( value=ast.Name("Literal", ctx=ast.Load()), slice=ast.Constant(value=typename), ctx=ast.Load(), ), - ctx=ast.Load(), + value=ast.Call( + func=ast.Name(id="Field", ctx=ast.Load()), + args=[], + keywords=keywords, ), + simple=1, + ) + +def generate_generic_typename_field( + registry: ClassRegistry, config: GeneratorConfig +): + """Generates the typename field a specific type, this will be used to determine the type of the object in the response""" + + registry.register_import("pydantic.Field") + registry.register_import("typing.Optional") + registry.register_import("typing.Literal") + + keywords = [ + ast.keyword(arg="alias", value=ast.Constant(value="__typename")), + ] + if config.exclude_typenames: + keywords.append(ast.keyword(arg="exclude", value=ast.Constant(value=True))) + + return ast.AnnAssign( + target=ast.Name(id="typename", ctx=ast.Store()), + annotation=ast.Name("str", ctx=ast.Load()), value=ast.Call( func=ast.Name(id="Field", ctx=ast.Load()), args=[], @@ -120,8 +152,12 @@ def generate_typename_field( simple=1, ) + def generate_config_dict( - graphQLType: GraphQLTypes, config: GeneratorConfig, registy: ClassRegistry, typename: str = None + graphQLType: GraphQLTypes, + config: GeneratorConfig, + registy: ClassRegistry, + typename: str = None, ): """Generates the config class for a specific type version 2 @@ -156,35 +192,55 @@ def generate_config_dict( else: if config.options.allow_mutation is not None: config_keywords.append( - ast.keyword(arg="allow_mutation", value=ast.Constant(value=config.options.allow_mutation)) + ast.keyword( + arg="allow_mutation", + value=ast.Constant(value=config.options.allow_mutation), + ) ) if config.options.extra is not None: config_keywords.append( - ast.keyword(arg="extra", value=ast.Constant(value=config.options.extra)) + ast.keyword( + arg="extra", value=ast.Constant(value=config.options.extra) + ) ) if config.options.validate_assignment is not None: config_keywords.append( - ast.keyword(arg="validate_assignment", value=ast.Constant(value=config.options.validate_assignment)) + ast.keyword( + arg="validate_assignment", + value=ast.Constant( + value=config.options.validate_assignment + ), + ) ) if config.options.allow_population_by_field_name is not None: config_keywords.append( - ast.keyword(arg="populate_by_name", value=ast.Constant(value=config.options.allow_population_by_field_name)) + ast.keyword( + arg="populate_by_name", + value=ast.Constant( + value=config.options.allow_population_by_field_name + ), + ) ) if config.options.orm_mode is not None: config_keywords.append( - ast.keyword(arg="orm_mode", value=ast.Constant(value=config.options.orm_mode)) + ast.keyword( + arg="orm_mode", + value=ast.Constant(value=config.options.orm_mode), + ) ) if config.options.use_enum_values is not None: config_keywords.append( - ast.keyword(arg="use_enum_values", value=ast.Constant(value=config.options.use_enum_values)) + ast.keyword( + arg="use_enum_values", + value=ast.Constant(value=config.options.use_enum_values), + ) ) - if typename: if typename in config.additional_config: for key, value in config.additional_config[typename].items(): @@ -192,16 +248,17 @@ def generate_config_dict( ast.keyword(arg=key, value=ast.Constant(value=value)) ) - if len(config_keywords) > 0: registy.register_import("pydantic.ConfigDict") return [ ast.Assign( targets=[ast.Name(id="model_config", ctx=ast.Store())], - value=ast.Call(func=ast.Name(id="ConfigDict", ctx=ast.Load()), - args=[], keywords=config_keywords) + value=ast.Call( + func=ast.Name(id="ConfigDict", ctx=ast.Load()), + args=[], + keywords=config_keywords, + ), ) - ] else: return [] @@ -331,15 +388,60 @@ def generate_config_class_pydantic( ] else: return [] - -def generate_pydantic_config(graphQLType: GraphQLTypes, config: GeneratorConfig, registry: ClassRegistry, typename: str = None): - if config.pydantic_version == "v2": + +def generate_pydantic_config( + graphQLType: GraphQLTypes, + config: GeneratorConfig, + registry: ClassRegistry, + typename: str = None, +): + if config.pydantic_version == "v2": return generate_config_dict(graphQLType, config, registry, typename) else: return generate_config_class_pydantic(graphQLType, config, typename) + +def add_typename_recursively(selection_set: SelectionSetNode, skip=False) -> None: + if selection_set is None: + return + + # Collect all existing fields in the selection set + selections = list(selection_set.selections) + has_typename = any( + isinstance(field, FieldNode) and field.name.value == "__typename" + for field in selections + ) + + # Add __typename if it's not already present + if not has_typename and not skip: + selections.append( + FieldNode( + name=NameNode(value="__typename"), + arguments=[], + directives=[], + selection_set=None, + ) + ) + + # Apply the function recursively to nested selection sets + for field in selections: + if isinstance(field, FieldNode) and field.selection_set: + add_typename_recursively(field.selection_set) + + # Update the selection set with potentially added __typename fields + selection_set.selections = tuple(selections) + +def auto_add_typename_field_to_all_objects(document: DocumentNode) -> DocumentNode: + for definition in document.definitions: + if isinstance(definition, (OperationDefinitionNode, FragmentDefinitionNode)): + add_typename_recursively(definition.selection_set, skip=isinstance(definition, OperationDefinitionNode)) + + + return document + + def parse_documents(client_schema: GraphQLSchema, scan_glob) -> DocumentNode: """ """ if not scan_glob: @@ -367,6 +469,11 @@ def parse_documents(client_schema: GraphQLSchema, scan_glob) -> DocumentNode: raise InvalidDocuments( "Invalid Documents \n" + "\n".join(str(e) for e in errors) ) + + + nodes = auto_add_typename_field_to_all_objects(nodes) + + return nodes @@ -374,6 +481,34 @@ def parse_documents(client_schema: GraphQLSchema, scan_glob) -> DocumentNode: fragment_searcher = re.compile(r"\.\.\.(?P[a-zA-Z]*)") + +def auto_add_typename_field_to_fragment_str(fragment_str: str) -> str: + x = parse(fragment_str) + for fragment in x.definitions: + if isinstance(fragment, FragmentDefinitionNode): + selections = list(fragment.selection_set.selections) + if not any(field.name.value == "__typename" for field in selections): + selections.append( + FieldNode( + name=NameNode(value="__typename"), + arguments=[], + directives=[], + selection_set=None, + ) + ) + fragment.selection_set.selections = tuple(selections) + + + + + + return print_ast(x) + + + + + + def replace_iteratively( pattern, registry, @@ -386,7 +521,7 @@ def replace_iteratively( else: try: level_down_pattern = "\n\n".join( - [registry.get_fragment_document(key) for key in new_fragments] + [auto_add_typename_field_to_fragment_str(registry.get_fragment_document(key)) for key in new_fragments] + [pattern] ) return replace_iteratively(