Skip to content

Commit ab74435

Browse files
committed
Use a list of fields and return type instead of annotations dictionary.
Use `__class__` as the first argument name so `_update_func_cell_for__class__` can update it.
1 parent 9c6ed47 commit ab74435

1 file changed

Lines changed: 29 additions & 41 deletions

File tree

Lib/dataclasses.py

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,9 @@ def __init__(self, globals):
443443
self.unconditional_adds = {}
444444
self.method_annotations = {}
445445

446-
def add_fn(self, name, args, body, *, locals=None, annotations=None,
447-
overwrite_error=False, unconditional_add=False, decorator=None):
446+
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
447+
overwrite_error=False, unconditional_add=False, decorator=None,
448+
annotation_fields=None):
448449
if locals is not None:
449450
self.locals.update(locals)
450451

@@ -465,8 +466,8 @@ def add_fn(self, name, args, body, *, locals=None, annotations=None,
465466

466467
self.names.append(name)
467468

468-
if annotations is not None:
469-
self.method_annotations[name] = annotations
469+
if annotation_fields is not None:
470+
self.method_annotations[name] = (annotation_fields, return_type)
470471

471472
args = ','.join(args)
472473
body = '\n'.join(body)
@@ -508,10 +509,13 @@ def add_fns_to_class(self, cls):
508509
# Now that we've generated the functions, assign them into cls.
509510
for name, fn in zip(self.names, fns):
510511
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
511-
if annotations := self.method_annotations.get(name):
512-
annotate_fn = _make_annotate_function(cls, annotations)
513-
annotate_fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}.__annotate__"
514512

513+
try:
514+
annotation_fields, return_type = self.method_annotations[name]
515+
except KeyError:
516+
pass
517+
else:
518+
annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
515519
fn.__annotate__ = annotate_fn
516520

517521
if self.unconditional_adds.get(name, False):
@@ -529,7 +533,7 @@ def add_fns_to_class(self, cls):
529533
raise TypeError(error_msg)
530534

531535

532-
def _make_annotate_function(cls, annotations):
536+
def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
533537
# Create an __annotate__ function for a dataclass
534538
# Try to return annotations in the same format as they would be
535539
# from a regular __init__ function
@@ -541,21 +545,20 @@ def __annotate__(format, /):
541545
match format:
542546
case Format.VALUE | Format.FORWARDREF | Format.STRING:
543547
cls_annotations = {}
544-
for base in reversed(cls.__mro__):
548+
for base in reversed(__class__.__mro__):
545549
cls_annotations.update(
546550
annotationlib.get_annotations(base, format=format)
547551
)
548552

549553
new_annotations = {}
550-
for k, v in annotations.items():
551-
try:
552-
new_annotations[k] = cls_annotations[k]
553-
except KeyError:
554-
# This should be the return value
555-
if format == Format.STRING:
556-
new_annotations[k] = annotationlib.type_repr(v)
557-
else:
558-
new_annotations[k] = v
554+
for k in annotation_fields:
555+
new_annotations[k] = cls_annotations[k]
556+
557+
if return_type is not MISSING:
558+
if format == Format.STRING:
559+
new_annotations["return"] = annotationlib.type_repr(return_type)
560+
else:
561+
new_annotations["return"] = return_type
559562

560563
return new_annotations
561564

@@ -565,6 +568,7 @@ def __annotate__(format, /):
565568
# This is a flag for _add_slots to know it needs to regenerate this method
566569
# In order to remove references to the original class when it is replaced
567570
__annotate__._generated_by_dataclasses = True
571+
__annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"
568572

569573
return __annotate__
570574

@@ -680,8 +684,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
680684
raise TypeError(f'non-default argument {f.name!r} '
681685
f'follows default argument {seen_default.name!r}')
682686

683-
annotations = {f.name: f.type for f in fields if f.init}
684-
annotations["return"] = None
687+
annotation_fields = [f.name for f in fields if f.init]
685688

686689
locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
687690
'__dataclass_builtins_object__': object}
@@ -715,7 +718,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
715718
[self_name] + _init_params,
716719
body_lines,
717720
locals=locals,
718-
annotations=annotations)
721+
return_type=None,
722+
annotation_fields=annotation_fields)
719723

720724

721725
def _frozen_get_del_attr(cls, fields, func_builder):
@@ -1395,26 +1399,10 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
13951399
else:
13961400
f.type = ann
13971401

1398-
# Fix references in generated __annotate__ methods
1399-
method = getattr(newcls, "__init__")
1400-
update_annotations = getattr(method.__annotate__, "_generated_by_dataclasses", False)
1401-
1402-
if update_annotations:
1403-
new_annotations = {}
1404-
1405-
# Get the previous annotations to know what to replace
1406-
old_annotations = method.__annotate__(annotationlib.Format.FORWARDREF)
1407-
1408-
for k, v in old_annotations.items():
1409-
try:
1410-
new_annotations[k] = newcls_ann[k]
1411-
except KeyError:
1412-
new_annotations[k] = v
1413-
1414-
new_annotate = _make_annotate_function(newcls, new_annotations)
1415-
new_annotate.__qualname__ = f"{newcls.__qualname__}.__init__.__annotate__"
1416-
1417-
setattr(method, "__annotate__", new_annotate)
1402+
# Fix the class reference in the __annotate__ method
1403+
init_annotate = newcls.__init__.__annotate__
1404+
if getattr(init_annotate, "_generated_by_dataclasses", False):
1405+
_update_func_cell_for__class__(init_annotate, cls, newcls)
14181406

14191407
return newcls
14201408

0 commit comments

Comments
 (0)