Skip to content

Commit 5820f94

Browse files
committed
replace source annotations in init with attached annotate function
1 parent be56464 commit 5820f94

1 file changed

Lines changed: 43 additions & 11 deletions

File tree

Lib/dataclasses.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,9 @@ def __init__(self, globals):
441441
self.locals = {}
442442
self.overwrite_errors = {}
443443
self.unconditional_adds = {}
444+
self.method_annotations = {}
444445

445-
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
446+
def add_fn(self, name, args, body, *, locals=None, annotations=None,
446447
overwrite_error=False, unconditional_add=False, decorator=None):
447448
if locals is not None:
448449
self.locals.update(locals)
@@ -464,16 +465,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
464465

465466
self.names.append(name)
466467

467-
if return_type is not MISSING:
468-
self.locals[f'__dataclass_{name}_return_type__'] = return_type
469-
return_annotation = f'->__dataclass_{name}_return_type__'
470-
else:
471-
return_annotation = ''
468+
if annotations is not None:
469+
self.method_annotations[name] = annotations
470+
472471
args = ','.join(args)
473472
body = '\n'.join(body)
474473

475474
# Compute the text of the entire function, add it to the text we're generating.
476-
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
475+
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')
477476

478477
def add_fns_to_class(self, cls):
479478
# The source to all of the functions we're generating.
@@ -509,6 +508,9 @@ def add_fns_to_class(self, cls):
509508
# Now that we've generated the functions, assign them into cls.
510509
for name, fn in zip(self.names, fns):
511510
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
511+
if annotations := self.method_annotations.get(name):
512+
fn.__annotate__ = self.make_annotate_function(annotations)
513+
512514
if self.unconditional_adds.get(name, False):
513515
setattr(cls, name, fn)
514516
else:
@@ -523,6 +525,34 @@ def add_fns_to_class(self, cls):
523525

524526
raise TypeError(error_msg)
525527

528+
@staticmethod
529+
def make_annotate_function(annotations):
530+
# Create an __annotate__ function for a dataclass
531+
# Try to return annotations in the same format as they would be
532+
# from a regular __init__ function
533+
def __annotate__(format):
534+
match format:
535+
case annotationlib.Format.VALUE | annotationlib.Format.FORWARDREF:
536+
return {
537+
k: v.evaluate(format=format)
538+
if isinstance(v, annotationlib.ForwardRef) else v
539+
for k, v in annotations.items()
540+
}
541+
case annotationlib.Format.STRING:
542+
string_annos = {}
543+
for k, v in annotations.items():
544+
if isinstance(v, str):
545+
string_annos[k] = v
546+
elif isinstance(v, annotationlib.ForwardRef):
547+
string_annos[k] = v.evaluate(format=annotationlib.Format.STRING)
548+
else:
549+
string_annos[k] = annotationlib.type_repr(v)
550+
return string_annos
551+
case _:
552+
raise NotImplementedError(format)
553+
554+
return __annotate__
555+
526556

527557
def _field_assign(frozen, name, value, self_name):
528558
# If we're a frozen class, then assign to our fields in __init__
@@ -612,7 +642,7 @@ def _init_param(f):
612642
elif f.default_factory is not MISSING:
613643
# There's a factory function. Set a marker.
614644
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
615-
return f'{f.name}:__dataclass_type_{f.name}__{default}'
645+
return f'{f.name}{default}'
616646

617647

618648
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
@@ -635,8 +665,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
635665
raise TypeError(f'non-default argument {f.name!r} '
636666
f'follows default argument {seen_default.name!r}')
637667

638-
locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
639-
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
668+
annotations = {f.name: f.type for f in fields if f.init}
669+
annotations["return"] = None
670+
671+
locals = {**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
640672
'__dataclass_builtins_object__': object,
641673
}
642674
}
@@ -670,7 +702,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
670702
[self_name] + _init_params,
671703
body_lines,
672704
locals=locals,
673-
return_type=None)
705+
annotations=annotations)
674706

675707

676708
def _frozen_get_del_attr(cls, fields, func_builder):

0 commit comments

Comments
 (0)