diff --git a/README.md b/README.md index 6319e3dd..a8899ba4 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,8 @@ Options: --specify-tags TEXT -c, --custom-visitor PATH --disable-timestamp + --include-request-argument Auto-inject a FastAPI Request parameter into + operations when not present. -d, --output-model-type [pydantic_v2.BaseModel|pydantic_v2.dataclass|dataclasses.dataclass|typing.TypedDict|msgspec.Struct] [default: pydantic_v2.BaseModel] -p, --python-version [3.10|3.11|3.12|3.13|3.14] diff --git a/docs/cli-reference.md b/docs/cli-reference.md index dd8d57e4..c007010d 100644 --- a/docs/cli-reference.md +++ b/docs/cli-reference.md @@ -18,6 +18,8 @@ Options: --specify-tags TEXT -c, --custom-visitor PATH --disable-timestamp + --include-request-argument Auto-inject a FastAPI Request parameter into + operations when not present. -d, --output-model-type [pydantic_v2.BaseModel|pydantic_v2.dataclass|dataclasses.dataclass|typing.TypedDict|msgspec.Struct] [default: pydantic_v2.BaseModel] -p, --python-version [3.10|3.11|3.12|3.13|3.14] @@ -72,6 +74,14 @@ Render generated files with a custom template directory. Input schema: `openapi/custom_template_security/custom_security.yaml` +### --include-request-argument + +Auto-inject a FastAPI Request argument in generated operation signatures when not present. + +`fastapi-codegen --input openapi/default_template/simple.yaml --output app --include-request-argument` + +Input schema: `openapi/default_template/simple.yaml` + ### --encoding Read the input schema using an explicit text encoding. diff --git a/docs/index.md b/docs/index.md index 051e94ab..5dc0e6b9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -57,6 +57,8 @@ Options: --specify-tags TEXT -c, --custom-visitor PATH --disable-timestamp + --include-request-argument Auto-inject a FastAPI Request parameter into + operations when not present. -d, --output-model-type [pydantic_v2.BaseModel|pydantic_v2.dataclass|dataclasses.dataclass|typing.TypedDict|msgspec.Struct] [default: pydantic_v2.BaseModel] -p, --python-version [3.10|3.11|3.12|3.13|3.14] diff --git a/docs/llms-full.txt b/docs/llms-full.txt index 8dcdbda0..8d6c6323 100644 --- a/docs/llms-full.txt +++ b/docs/llms-full.txt @@ -59,6 +59,8 @@ Options: --specify-tags TEXT -c, --custom-visitor PATH --disable-timestamp + --include-request-argument Auto-inject a FastAPI Request parameter into + operations when not present. -d, --output-model-type [pydantic_v2.BaseModel|pydantic_v2.dataclass|dataclasses.dataclass|typing.TypedDict|msgspec.Struct] [default: pydantic_v2.BaseModel] -p, --python-version [3.10|3.11|3.12|3.13|3.14] @@ -339,6 +341,8 @@ Options: --specify-tags TEXT -c, --custom-visitor PATH --disable-timestamp + --include-request-argument Auto-inject a FastAPI Request parameter into + operations when not present. -d, --output-model-type [pydantic_v2.BaseModel|pydantic_v2.dataclass|dataclasses.dataclass|typing.TypedDict|msgspec.Struct] [default: pydantic_v2.BaseModel] -p, --python-version [3.10|3.11|3.12|3.13|3.14] @@ -393,6 +397,14 @@ Render generated files with a custom template directory. Input schema: `openapi/custom_template_security/custom_security.yaml` +### --include-request-argument + +Auto-inject a FastAPI Request argument in generated operation signatures when not present. + +`fastapi-codegen --input openapi/default_template/simple.yaml --output app --include-request-argument` + +Input schema: `openapi/default_template/simple.yaml` + ### --encoding Read the input schema using an explicit text encoding. diff --git a/fastapi_code_generator/_types/generate_config_dict.py b/fastapi_code_generator/_types/generate_config_dict.py index 97576a32..4c355787 100644 --- a/fastapi_code_generator/_types/generate_config_dict.py +++ b/fastapi_code_generator/_types/generate_config_dict.py @@ -14,6 +14,7 @@ class GenerateConfigDict(TypedDict): encoding: NotRequired[str] enum_field_as_literal: NotRequired[Literal['all', 'one', 'none'] | None] generate_routers: NotRequired[bool] + include_request_argument: NotRequired[bool] input_file: str model_file: NotRequired[str | None] model_template_dir: NotRequired[str | None] diff --git a/fastapi_code_generator/cli.py b/fastapi_code_generator/cli.py index df237e6a..4bcd3bd1 100644 --- a/fastapi_code_generator/cli.py +++ b/fastapi_code_generator/cli.py @@ -90,6 +90,14 @@ def main( None, "--custom-visitor", "-c" ), disable_timestamp: bool = typer.Option(False, "--disable-timestamp"), + include_request_argument: bool = typer.Option( + False, + "--include-request-argument", + help=( + "Auto-inject a FastAPI Request parameter into operations when not " + "present." + ), + ), output_model_type: DataModelType = typer.Option( DataModelType.PydanticV2BaseModel.value, "--output-model-type", "-d" ), @@ -125,6 +133,7 @@ def main( enum_field_as_literal=enum_field_as_literal or None, custom_visitors=custom_visitors, disable_timestamp=disable_timestamp, + include_request_argument=include_request_argument, generate_routers=generate_routers, specify_tags=specify_tags, output_model_type=output_model_type, @@ -163,6 +172,7 @@ def generate_code( enum_field_as_literal: Optional[LiteralType] = None, custom_visitors: Optional[List[Path]] = None, disable_timestamp: bool = False, + include_request_argument: bool = False, generate_routers: Optional[bool] = None, specify_tags: Optional[str] = None, output_model_type: DataModelType = DataModelType.PydanticV2BaseModel, @@ -195,6 +205,7 @@ def generate_code( dump_resolve_reference_action=data_model_types.dump_resolve_reference_action, custom_template_dir=model_template_dir, target_python_version=python_version, + include_request_argument=include_request_argument, use_annotated=use_annotated, ) diff --git a/fastapi_code_generator/config.py b/fastapi_code_generator/config.py index cb296509..d827b505 100644 --- a/fastapi_code_generator/config.py +++ b/fastapi_code_generator/config.py @@ -150,6 +150,14 @@ class GenerateConfig(BaseModel): description="Omit timestamp headers from generated files.", json_schema_extra=cast(Any, _cli_metadata("--disable-timestamp")), ) + include_request_argument: bool = Field( + default=False, + description=( + "Auto-inject a FastAPI Request argument in generated operation " + "signatures when not present." + ), + json_schema_extra=cast(Any, _cli_metadata("--include-request-argument")), + ) output_model_type: OutputModelTypeName = Field( default="pydantic_v2.BaseModel", description="Model backend passed through to datamodel-code-generator.", diff --git a/fastapi_code_generator/parser.py b/fastapi_code_generator/parser.py index 102677a8..8967bb62 100644 --- a/fastapi_code_generator/parser.py +++ b/fastapi_code_generator/parser.py @@ -101,37 +101,35 @@ def __str__(self) -> str: # pragma: no cover return self.argument @property - def argument(self) -> str: # pragma: no cover + def resolved_type_hint(self) -> UsefulStr: if self.field is None: - type_hint = self.type_hint - else: - type_hint = ( - UsefulStr(self.field.type_hint) - if not isinstance(self.field, list) - else UsefulStr( - f"Union[{', '.join(field.type_hint for field in self.field)}]" - ) + return self.type_hint + return ( + UsefulStr(self.field.type_hint) + if not isinstance(self.field, list) + else UsefulStr( + f"Union[{', '.join(field.type_hint for field in self.field)}]" ) + ) + + @property + def argument(self) -> str: # pragma: no cover + type_hint = self.resolved_type_hint if self.default is None and self.required: return f'{self.name}: {type_hint}' return f'{self.name}: {type_hint} = {self.default}' @property def snakecase(self) -> str: - if self.field is None: - type_hint = self.type_hint - else: - type_hint = ( - UsefulStr(self.field.type_hint) - if not isinstance(self.field, list) - else UsefulStr( - f"Union[{', '.join(field.type_hint for field in self.field)}]" - ) - ) + type_hint = self.resolved_type_hint if self.default is None and self.required: return f'{stringcase.snakecase(self.name)}: {type_hint}' return f'{stringcase.snakecase(self.name)}: {type_hint} = {self.default}' + @property + def plain_parameter(self) -> str: + return self.snakecase + class Operation(CachedPropertyModel): method: UsefulStr @@ -191,20 +189,34 @@ def type(self) -> UsefulStr: """ return self.method + @cached_property + def _merged_arguments(self) -> List[Argument]: + return Operation.merge_arguments_with_union(self.arguments_list) + @property def arguments(self) -> str: # pragma: no cover - sorted_arguments = Operation.merge_arguments_with_union(self.arguments_list) - return ", ".join(argument.argument for argument in sorted_arguments) + return ", ".join(argument.argument for argument in self._merged_arguments) @property def snake_case_arguments(self) -> str: - sorted_arguments = Operation.merge_arguments_with_union(self.arguments_list) - return ", ".join(argument.snakecase for argument in sorted_arguments) + return ", ".join(argument.snakecase for argument in self._merged_arguments) + + @property + def plain_arguments(self) -> str: + return ", ".join( + stringcase.snakecase(argument.name) for argument in self._merged_arguments + ) + + @property + def plain_parameters(self) -> str: + return ", ".join( + argument.plain_parameter for argument in self._merged_arguments + ) @property def imports(self) -> Imports: imports = Imports() - for argument in Operation.merge_arguments_with_union(self.arguments_list): + for argument in self._merged_arguments: if isinstance(argument.field, list): for field in argument.field: imports.append(field.data_type.import_) @@ -274,6 +286,7 @@ def __init__( custom_class_name_generator: Optional[Callable[[str], str]] = None, field_extra_keys: Optional[Set[str]] = None, field_include_all_keys: bool = False, + include_request_argument: bool = False, use_annotated: bool = False, ): super().__init__( @@ -320,6 +333,7 @@ def __init__( self._temporary_operation: Dict[str, Any] = {} self.imports_for_fastapi: Imports = Imports() self.data_types: List[DataType] = [] + self.include_request_argument = include_request_argument def parse_info(self) -> Optional[Dict[str, Any]]: if not isinstance(self.raw_obj, dict): # pragma: no cover @@ -442,6 +456,19 @@ def get_argument_list( if request: arguments.append(request) + if self.include_request_argument and not any( + argument.name == "request" for argument in arguments + ): + arguments.insert( + 0, + Argument( + name='request', # type: ignore + type_hint='Request', # type: ignore + required=True, + ), + ) + self.imports_for_fastapi.append(Import.from_full_path("fastapi.Request")) + positional_argument: bool = False for argument in arguments: if positional_argument and argument.required and argument.default is None: @@ -502,7 +529,7 @@ def parse_request_body( ) ) self.imports_for_fastapi.append( - Import.from_full_path('starlette.requests.Request') + Import.from_full_path('fastapi.Request') ) elif media_type == 'application/octet-stream': arguments.append( diff --git a/fastapi_code_generator/prompt_data.py b/fastapi_code_generator/prompt_data.py index bca47ce0..e0ff260f 100644 --- a/fastapi_code_generator/prompt_data.py +++ b/fastapi_code_generator/prompt_data.py @@ -121,6 +121,17 @@ 'type': 'boolean', 'choices': [], }, + { + 'name': 'include_request_argument', + 'cli_flags': ['--include-request-argument'], + 'description': 'Auto-inject a FastAPI Request argument in ' + 'generated operation signatures when not present.', + 'required': False, + 'default': False, + 'multiple': False, + 'type': 'boolean', + 'choices': [], + }, { 'name': 'output_model_type', 'cli_flags': ['--output-model-type', '-d'], @@ -212,6 +223,19 @@ ], 'input_schema': 'openapi/custom_template_security/custom_security.yaml', }, + { + 'options': ['--include-request-argument'], + 'description': 'Auto-inject a FastAPI Request argument in generated ' + 'operation signatures when not present.', + 'cli_args': [ + '--input', + 'openapi/default_template/simple.yaml', + '--output', + 'app', + '--include-request-argument', + ], + 'input_schema': 'openapi/default_template/simple.yaml', + }, { 'options': ['--encoding'], 'description': 'Read the input schema using an explicit text ' 'encoding.', diff --git a/tests/data/expected/openapi/coverage/model_options/main.py b/tests/data/expected/openapi/coverage/model_options/main.py index f26f56a7..b909bfaa 100644 --- a/tests/data/expected/openapi/coverage/model_options/main.py +++ b/tests/data/expected/openapi/coverage/model_options/main.py @@ -7,7 +7,6 @@ from typing import List, Optional, Union from fastapi import FastAPI, Path, Query, Request -from starlette.requests import Request from .custom_models import ( Error, diff --git a/tests/data/expected/openapi/default_template/body_and_parameters/main.py b/tests/data/expected/openapi/default_template/body_and_parameters/main.py index f8075261..1b63a496 100644 --- a/tests/data/expected/openapi/default_template/body_and_parameters/main.py +++ b/tests/data/expected/openapi/default_template/body_and_parameters/main.py @@ -7,7 +7,6 @@ from typing import List, Optional, Union from fastapi import FastAPI, Path, Query, Request -from starlette.requests import Request from .models import ( Error, diff --git a/tests/main/test_main.py b/tests/main/test_main.py index eb04ef35..73b67f03 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -262,6 +262,97 @@ def test_generate_escapes_aliases_in_parameter_defaults(output_dir: Path) -> Non validate_generated_code(output_dir) +@freeze_time("2020-06-19") +def test_custom_template_can_use_plain_arguments( + tmp_path: Path, output_dir: Path +) -> None: + template_dir = tmp_path / "template" + template_dir.mkdir() + template_dir.joinpath("main.jinja2").write_text( + """ +PLAIN_ARGUMENTS = "{{ operations[0].plain_arguments }}" +PLAIN_PARAMETERS = "{{ operations[0].plain_parameters }}" +LEGACY_ARGUMENTS = "{{ operations[0].snake_case_arguments }}" +""", + encoding="utf-8", + ) + spec = """openapi: 3.0.0 +info: + title: Plain arguments + version: 1.0.0 +paths: + /pets/{pet_id}: + get: + operationId: listPets + responses: + '200': + description: ok + parameters: + - name: pet_id + in: path + required: true + schema: + type: string + - name: limit + in: query + required: false + schema: + type: integer + default: 0 +""" + + generate_code( + "plain_arguments.yaml", + spec, + "utf-8", + output_dir, + template_dir, + ) + + generated = output_dir.joinpath("main.py").read_text(encoding="utf-8") + assert 'PLAIN_ARGUMENTS = "pet_id, limit"' in generated + assert 'PLAIN_PARAMETERS = "pet_id: str, limit: Optional[int] = 0"' in generated + assert 'LEGACY_ARGUMENTS = "pet_id: str, limit: Optional[int] = 0"' in generated + validate_generated_code(output_dir) + + +@pytest.mark.cli_doc( + options=["--include-request-argument"], + option_description=( + "Auto-inject a FastAPI Request argument in generated operation signatures " + "when not present." + ), + cli_args=[ + "--input", + "openapi/default_template/simple.yaml", + "--output", + "app", + "--include-request-argument", + ], + input_schema="openapi/default_template/simple.yaml", +) +@freeze_time("2020-06-19") +def test_include_request_argument(output_dir: Path) -> None: + assert ( + run_main_with_args( + [ + "--input", + str(DATA_PATH / OPEN_API_DEFAULT_TEMPLATE_DIR_NAME / "simple.yaml"), + "--output", + str(output_dir), + "--include-request-argument", + ] + ) + == 0 + ) + + generated = output_dir.joinpath("main.py").read_text(encoding="utf-8") + assert "Request" in generated + assert "def list_pets(" in generated + assert "request: Request" in generated + validate_generated_code(output_dir) + + @pytest.mark.cli_doc( options=["--encoding"], option_description="Read the input schema using an explicit text encoding.", diff --git a/tests/test_config.py b/tests/test_config.py index 3e1f657f..f2d607dd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -31,6 +31,7 @@ def test_generate_config_defaults() -> None: assert config.python_version == "3.10" assert config.custom_visitors is None assert config.generate_routers is False + assert config.include_request_argument is False def test_validate_generate_config_model_matches_cli() -> None: