@@ -50,7 +50,9 @@ class GraphQLStrategy:
5050 # This is a per-method cache without limits as they are proportionate to the schema size
5151 _cache : Dict [str , Dict ] = attr .ib (factory = dict )
5252
53- def values (self , type_ : graphql .GraphQLInputType ) -> st .SearchStrategy [InputTypeNode ]:
53+ def values (
54+ self , type_ : graphql .GraphQLInputType , default : Optional [graphql .ValueNode ] = None
55+ ) -> st .SearchStrategy [InputTypeNode ]:
5456 """Generate value nodes of a type, that corresponds to the input type.
5557
5658 They correspond to all `GraphQLInputType` variants:
@@ -67,23 +69,25 @@ def values(self, type_: graphql.GraphQLInputType) -> st.SearchStrategy[InputType
6769 if isinstance (type_ , graphql .GraphQLScalarType ):
6870 type_name = type_ .name
6971 if type_name in self .custom_scalars :
70- return primitives .maybe_null (self .custom_scalars [type_name ], nullable )
71- return primitives .scalar (type_name , nullable )
72+ return primitives .custom (self .custom_scalars [type_name ], nullable , default = default )
73+ return primitives .scalar (type_name , nullable , default = default )
7274 if isinstance (type_ , graphql .GraphQLEnumType ):
7375 values = tuple (type_ .values )
74- return primitives .enum (values , nullable )
76+ return primitives .enum (values , nullable , default = default )
7577 # Types with children
7678 if isinstance (type_ , graphql .GraphQLList ):
77- return self .lists (type_ , nullable )
79+ return self .lists (type_ , nullable , default = default )
7880 if isinstance (type_ , graphql .GraphQLInputObjectType ):
7981 return self .objects (type_ , nullable )
8082 raise TypeError (f"Type { type_ .__class__ .__name__ } is not supported." )
8183
82- @instance_cache (lambda type_ , nullable = True : (make_type_name (type_ ), nullable ))
83- def lists (self , type_ : graphql .GraphQLList , nullable : bool = True ) -> st .SearchStrategy [graphql .ListValueNode ]:
84+ @instance_cache (lambda type_ , nullable = True , default = None : (make_type_name (type_ ), nullable , default ))
85+ def lists (
86+ self , type_ : graphql .GraphQLList , nullable : bool = True , default : Optional [graphql .ValueNode ] = None
87+ ) -> st .SearchStrategy [graphql .ListValueNode ]:
8488 """Generate a `graphql.ListValueNode`."""
8589 strategy = st .lists (self .values (type_ .of_type ))
86- return primitives .maybe_null (strategy . map ( nodes . List ) , nullable )
90+ return primitives .list_ (strategy , nullable , default = default )
8791
8892 @instance_cache (lambda type_ , nullable = True : (type_ .name , nullable ))
8993 def objects (
@@ -113,11 +117,16 @@ def can_generate_field(self, field: graphql.GraphQLInputField) -> bool:
113117 )
114118
115119 def lists_of_object_fields (
116- self , items : List [Tuple [str , Field ]]
120+ self , items : List [Tuple [str , graphql . GraphQLInputField ]]
117121 ) -> st .SearchStrategy [List [graphql .ObjectFieldNode ]]:
118- return st .tuples (* (self .values (field .type ).map (factories .object_field (name )) for name , field in items )).map (
119- list
120- )
122+ return st .tuples (
123+ * (
124+ self .values (field .type , field .ast_node .default_value if field .ast_node is not None else None ).map (
125+ factories .object_field (name )
126+ )
127+ for name , field in items
128+ )
129+ ).map (list )
121130
122131 @instance_cache (lambda interface , implementations : (interface .name , tuple (impl .name for impl in implementations )))
123132 def interfaces (
@@ -202,8 +211,9 @@ def list_of_arguments(
202211 def inner (draw : Any ) -> List [graphql .ArgumentNode ]:
203212 args = []
204213 for name , argument in arguments .items ():
214+ default = argument .ast_node .default_value if argument .ast_node is not None else None
205215 try :
206- argument_strategy = self .values (argument .type )
216+ argument_strategy = self .values (argument .type , default = default )
207217 except InvalidArgument :
208218 if not isinstance (argument .type , graphql .GraphQLNonNull ):
209219 # If the type is nullable, then either generate `null` or skip it completely
@@ -305,8 +315,8 @@ def add_alias(frag: graphql.InlineFragmentNode) -> graphql.InlineFragmentNode:
305315
306316
307317def subset_of_fields (
308- fields : Dict [str , Field ], * , force_required : bool = False
309- ) -> st .SearchStrategy [List [Tuple [str , Field ]]]:
318+ fields : Dict [str , graphql . GraphQLInputField ], * , force_required : bool = False
319+ ) -> st .SearchStrategy [List [Tuple [str , graphql . GraphQLInputField ]]]:
310320 """A helper to select a subset of fields."""
311321 field_pairs = sorted (fields .items ())
312322 # if we need to always generate required fields, then return them and extend with a subset of optional fields
0 commit comments