@@ -443,8 +443,9 @@ def __init__(self, globals):
443443 self .unconditional_adds = {}
444444 self .method_annotations = {}
445445
446- def add_fn (self , name , args , body , * , locals = None , annotations = None ,
447- overwrite_error = False , unconditional_add = False , decorator = None ):
446+ def add_fn (self , name , args , body , * , locals = None , return_type = MISSING ,
447+ overwrite_error = False , unconditional_add = False , decorator = None ,
448+ annotation_fields = None ):
448449 if locals is not None :
449450 self .locals .update (locals )
450451
@@ -465,8 +466,8 @@ def add_fn(self, name, args, body, *, locals=None, annotations=None,
465466
466467 self .names .append (name )
467468
468- if annotations is not None :
469- self .method_annotations [name ] = annotations
469+ if annotation_fields is not None :
470+ self .method_annotations [name ] = ( annotation_fields , return_type )
470471
471472 args = ',' .join (args )
472473 body = '\n ' .join (body )
@@ -508,10 +509,13 @@ def add_fns_to_class(self, cls):
508509 # Now that we've generated the functions, assign them into cls.
509510 for name , fn in zip (self .names , fns ):
510511 fn .__qualname__ = f"{ cls .__qualname__ } .{ fn .__name__ } "
511- if annotations := self .method_annotations .get (name ):
512- annotate_fn = _make_annotate_function (cls , annotations )
513- annotate_fn .__qualname__ = f"{ cls .__qualname__ } .{ fn .__name__ } .__annotate__"
514512
513+ try :
514+ annotation_fields , return_type = self .method_annotations [name ]
515+ except KeyError :
516+ pass
517+ else :
518+ annotate_fn = _make_annotate_function (cls , name , annotation_fields , return_type )
515519 fn .__annotate__ = annotate_fn
516520
517521 if self .unconditional_adds .get (name , False ):
@@ -529,7 +533,7 @@ def add_fns_to_class(self, cls):
529533 raise TypeError (error_msg )
530534
531535
532- def _make_annotate_function (cls , annotations ):
536+ def _make_annotate_function (__class__ , method_name , annotation_fields , return_type ):
533537 # Create an __annotate__ function for a dataclass
534538 # Try to return annotations in the same format as they would be
535539 # from a regular __init__ function
@@ -541,21 +545,20 @@ def __annotate__(format, /):
541545 match format :
542546 case Format .VALUE | Format .FORWARDREF | Format .STRING :
543547 cls_annotations = {}
544- for base in reversed (cls .__mro__ ):
548+ for base in reversed (__class__ .__mro__ ):
545549 cls_annotations .update (
546550 annotationlib .get_annotations (base , format = format )
547551 )
548552
549553 new_annotations = {}
550- for k , v in annotations .items ():
551- try :
552- new_annotations [k ] = cls_annotations [k ]
553- except KeyError :
554- # This should be the return value
555- if format == Format .STRING :
556- new_annotations [k ] = annotationlib .type_repr (v )
557- else :
558- new_annotations [k ] = v
554+ for k in annotation_fields :
555+ new_annotations [k ] = cls_annotations [k ]
556+
557+ if return_type is not MISSING :
558+ if format == Format .STRING :
559+ new_annotations ["return" ] = annotationlib .type_repr (return_type )
560+ else :
561+ new_annotations ["return" ] = return_type
559562
560563 return new_annotations
561564
@@ -565,6 +568,7 @@ def __annotate__(format, /):
565568 # This is a flag for _add_slots to know it needs to regenerate this method
566569 # In order to remove references to the original class when it is replaced
567570 __annotate__ ._generated_by_dataclasses = True
571+ __annotate__ .__qualname__ = f"{ __class__ .__qualname__ } .{ method_name } .__annotate__"
568572
569573 return __annotate__
570574
@@ -680,8 +684,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
680684 raise TypeError (f'non-default argument { f .name !r} '
681685 f'follows default argument { seen_default .name !r} ' )
682686
683- annotations = {f .name : f .type for f in fields if f .init }
684- annotations ["return" ] = None
687+ annotation_fields = [f .name for f in fields if f .init ]
685688
686689 locals = {'__dataclass_HAS_DEFAULT_FACTORY__' : _HAS_DEFAULT_FACTORY ,
687690 '__dataclass_builtins_object__' : object }
@@ -715,7 +718,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
715718 [self_name ] + _init_params ,
716719 body_lines ,
717720 locals = locals ,
718- annotations = annotations )
721+ return_type = None ,
722+ annotation_fields = annotation_fields )
719723
720724
721725def _frozen_get_del_attr (cls , fields , func_builder ):
@@ -1395,26 +1399,10 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
13951399 else :
13961400 f .type = ann
13971401
1398- # Fix references in generated __annotate__ methods
1399- method = getattr (newcls , "__init__" )
1400- update_annotations = getattr (method .__annotate__ , "_generated_by_dataclasses" , False )
1401-
1402- if update_annotations :
1403- new_annotations = {}
1404-
1405- # Get the previous annotations to know what to replace
1406- old_annotations = method .__annotate__ (annotationlib .Format .FORWARDREF )
1407-
1408- for k , v in old_annotations .items ():
1409- try :
1410- new_annotations [k ] = newcls_ann [k ]
1411- except KeyError :
1412- new_annotations [k ] = v
1413-
1414- new_annotate = _make_annotate_function (newcls , new_annotations )
1415- new_annotate .__qualname__ = f"{ newcls .__qualname__ } .__init__.__annotate__"
1416-
1417- setattr (method , "__annotate__" , new_annotate )
1402+ # Fix the class reference in the __annotate__ method
1403+ init_annotate = newcls .__init__ .__annotate__
1404+ if getattr (init_annotate , "_generated_by_dataclasses" , False ):
1405+ _update_func_cell_for__class__ (init_annotate , cls , newcls )
14181406
14191407 return newcls
14201408
0 commit comments