@@ -441,8 +441,9 @@ def __init__(self, globals):
441441 self .locals = {}
442442 self .overwrite_errors = {}
443443 self .unconditional_adds = {}
444+ self .method_annotations = {}
444445
445- def add_fn (self , name , args , body , * , locals = None , return_type = MISSING ,
446+ def add_fn (self , name , args , body , * , locals = None , annotations = None ,
446447 overwrite_error = False , unconditional_add = False , decorator = None ):
447448 if locals is not None :
448449 self .locals .update (locals )
@@ -464,16 +465,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
464465
465466 self .names .append (name )
466467
467- if return_type is not MISSING :
468- self .locals [f'__dataclass_{ name } _return_type__' ] = return_type
469- return_annotation = f'->__dataclass_{ name } _return_type__'
470- else :
471- return_annotation = ''
468+ if annotations is not None :
469+ self .method_annotations [name ] = annotations
470+
472471 args = ',' .join (args )
473472 body = '\n ' .join (body )
474473
475474 # Compute the text of the entire function, add it to the text we're generating.
476- self .src .append (f'{ f' { decorator } \n ' if decorator else '' } def { name } ({ args } ){ return_annotation } :\n { body } ' )
475+ self .src .append (f'{ f' { decorator } \n ' if decorator else '' } def { name } ({ args } ):\n { body } ' )
477476
478477 def add_fns_to_class (self , cls ):
479478 # The source to all of the functions we're generating.
@@ -509,6 +508,9 @@ def add_fns_to_class(self, cls):
509508 # Now that we've generated the functions, assign them into cls.
510509 for name , fn in zip (self .names , fns ):
511510 fn .__qualname__ = f"{ cls .__qualname__ } .{ fn .__name__ } "
511+ if annotations := self .method_annotations .get (name ):
512+ fn .__annotate__ = self .make_annotate_function (annotations )
513+
512514 if self .unconditional_adds .get (name , False ):
513515 setattr (cls , name , fn )
514516 else :
@@ -523,6 +525,34 @@ def add_fns_to_class(self, cls):
523525
524526 raise TypeError (error_msg )
525527
528+ @staticmethod
529+ def make_annotate_function (annotations ):
530+ # Create an __annotate__ function for a dataclass
531+ # Try to return annotations in the same format as they would be
532+ # from a regular __init__ function
533+ def __annotate__ (format ):
534+ match format :
535+ case annotationlib .Format .VALUE | annotationlib .Format .FORWARDREF :
536+ return {
537+ k : v .evaluate (format = format )
538+ if isinstance (v , annotationlib .ForwardRef ) else v
539+ for k , v in annotations .items ()
540+ }
541+ case annotationlib .Format .STRING :
542+ string_annos = {}
543+ for k , v in annotations .items ():
544+ if isinstance (v , str ):
545+ string_annos [k ] = v
546+ elif isinstance (v , annotationlib .ForwardRef ):
547+ string_annos [k ] = v .evaluate (format = annotationlib .Format .STRING )
548+ else :
549+ string_annos [k ] = annotationlib .type_repr (v )
550+ return string_annos
551+ case _:
552+ raise NotImplementedError (format )
553+
554+ return __annotate__
555+
526556
527557def _field_assign (frozen , name , value , self_name ):
528558 # If we're a frozen class, then assign to our fields in __init__
@@ -612,7 +642,7 @@ def _init_param(f):
612642 elif f .default_factory is not MISSING :
613643 # There's a factory function. Set a marker.
614644 default = '=__dataclass_HAS_DEFAULT_FACTORY__'
615- return f'{ f .name } :__dataclass_type_ { f . name } __ { default } '
645+ return f'{ f .name } { default } '
616646
617647
618648def _init_fn (fields , std_fields , kw_only_fields , frozen , has_post_init ,
@@ -635,8 +665,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
635665 raise TypeError (f'non-default argument { f .name !r} '
636666 f'follows default argument { seen_default .name !r} ' )
637667
638- locals = {** {f'__dataclass_type_{ f .name } __' : f .type for f in fields },
639- ** {'__dataclass_HAS_DEFAULT_FACTORY__' : _HAS_DEFAULT_FACTORY ,
668+ annotations = {f .name : f .type for f in fields if f .init }
669+ annotations ["return" ] = None
670+
671+ locals = {** {'__dataclass_HAS_DEFAULT_FACTORY__' : _HAS_DEFAULT_FACTORY ,
640672 '__dataclass_builtins_object__' : object ,
641673 }
642674 }
@@ -670,7 +702,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
670702 [self_name ] + _init_params ,
671703 body_lines ,
672704 locals = locals ,
673- return_type = None )
705+ annotations = annotations )
674706
675707
676708def _frozen_get_del_attr (cls , fields , func_builder ):
0 commit comments