From a00be0ee0ab8d1bd6e5bc2264bee129debfe0fe7 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Mon, 27 Apr 2026 17:50:48 +0900 Subject: [PATCH] Fix multiple file upload generation --- fastapi_code_generator/parser.py | 67 +++++++++- fastapi_code_generator/visitors/imports.py | 44 +++++++ .../openapi/coverage/callbacks/main.py | 2 +- .../callbacks_with_operation_id/main.py | 2 +- .../openapi/default_template/upload/main.py | 17 ++- .../data/openapi/default_template/upload.yaml | 33 +++++ tests/test_parser.py | 121 ++++++++++++++++++ 7 files changed, 279 insertions(+), 7 deletions(-) create mode 100644 tests/test_parser.py diff --git a/fastapi_code_generator/parser.py b/fastapi_code_generator/parser.py index 3db4d0cf..7f1303ab 100644 --- a/fastapi_code_generator/parser.py +++ b/fastapi_code_generator/parser.py @@ -381,7 +381,7 @@ def get_parameter_type( if isinstance(content.schema_, ReferenceObject): data_type = self.get_ref_data_type(content.schema_.ref) ref_model = self.get_ref_model(content.schema_.ref) - schema = JsonSchemaObject.parse_obj(ref_model) # pragma: no cover + schema = JsonSchemaObject.model_validate(ref_model) # pragma: no cover else: schema = content.schema_ break @@ -494,7 +494,23 @@ def parse_request_body( request_body: RequestBodyObject, path: List[str], ) -> Dict[str, DataType]: - request_body_fields = super().parse_request_body(name, request_body, path) + if 'multipart/form-data' in request_body.content: + content = { + media_type: media_obj + for media_type, media_obj in request_body.content.items() + if media_type != 'multipart/form-data' + } + request_body_fields = ( + super().parse_request_body( + name, + request_body.model_copy(update={'content': content}), + path, + ) + if content + else {} + ) + else: + request_body_fields = super().parse_request_body(name, request_body, path) arguments: List[Argument] = [] for ( media_type, @@ -545,10 +561,11 @@ def parse_request_body( Import.from_full_path("fastapi.Request") ) elif media_type == 'multipart/form-data': + file_name, type_hint = self._get_upload_file_type(media_obj.schema_) arguments.append( Argument( - name='file', # type: ignore - type_hint='UploadFile', # type: ignore + name=file_name, # type: ignore + type_hint=type_hint, # type: ignore required=True, ) ) @@ -558,6 +575,48 @@ def parse_request_body( self._temporary_operation['_request'] = arguments[0] if arguments else None return request_body_fields + def _get_upload_file_type( + self, schema: Union[JsonSchemaObject, ReferenceObject] + ) -> tuple[str, str]: + if isinstance(schema, ReferenceObject): + schema = JsonSchemaObject.model_validate(self.get_ref_model(schema.ref)) + file_name = self._get_upload_file_name(schema) + if self._is_upload_file_array(schema): + self.imports_for_fastapi.append(Import(from_='typing', import_='List')) + return file_name, 'List[UploadFile]' + return file_name, 'UploadFile' + + def _get_upload_file_name(self, schema: JsonSchemaObject) -> str: + if schema.properties: + # The operation template supports one multipart upload argument. + for property_name, property_schema in schema.properties.items(): + if self._is_upload_file_array( + property_schema + ) or self._is_upload_file_schema(property_schema): + return stringcase.snakecase( + self.model_resolver.get_valid_field_name(property_name) + ) + return 'file' + + def _is_upload_file_array(self, schema: Any) -> bool: + if not isinstance(schema, JsonSchemaObject): + return False + if schema.is_array and self._is_upload_file_schema(schema.items): + return True + if schema.properties: + return any( + self._is_upload_file_array(property_schema) + for property_schema in schema.properties.values() + ) + return False + + def _is_upload_file_schema(self, schema: Any) -> bool: + return ( + isinstance(schema, JsonSchemaObject) + and schema.type == 'string' + and schema.format == 'binary' + ) + def parse_responses( # type: ignore[override] self, name: str, diff --git a/fastapi_code_generator/visitors/imports.py b/fastapi_code_generator/visitors/imports.py index b1b37be6..d186c18f 100644 --- a/fastapi_code_generator/visitors/imports.py +++ b/fastapi_code_generator/visitors/imports.py @@ -1,3 +1,4 @@ +import re from pathlib import Path from typing import Dict, Optional @@ -8,6 +9,8 @@ from fastapi_code_generator.parser import OpenAPIParser from fastapi_code_generator.visitor import Visitor +IDENTIFIER_PATTERN = re.compile(r'\b[A-Za-z_][A-Za-z0-9_]*\b') + def _get_most_of_reference(data_type: DataType) -> Optional[Reference]: if data_type.reference: @@ -19,6 +22,46 @@ def _get_most_of_reference(data_type: DataType) -> Optional[Reference]: return None +def _collect_used_names(parser: OpenAPIParser) -> set[str]: + names: set[str] = set() + pending_operations = list(parser.operations.values()) + while pending_operations: + operation = pending_operations.pop() + names.update(IDENTIFIER_PATTERN.findall(operation.arguments)) + names.update(IDENTIFIER_PATTERN.findall(operation.return_type)) + names.update(IDENTIFIER_PATTERN.findall(operation.response)) + for models in operation.additional_responses.values(): + for model in models.values(): + names.update(IDENTIFIER_PATTERN.findall(model)) + for callback_operations in operation.callbacks.values(): + pending_operations.extend(callback_operations) + return names + + +def _remove_unused_imports(imports: Imports, used_names: set[str]) -> None: + unused = [ + (from_, import_) + for from_, imports_ in imports.items() + for import_ in imports_ + if not {imports.get_effective_name(from_, import_), import_}.intersection( + used_names + ) + ] + reverse_lookup = { + (import_.from_, import_.import_): reference_path + for reference_path, import_ in imports.reference_paths.items() + } + for from_, import_ in unused: + import_obj = Import( + from_=from_, + import_=import_, + alias=imports.alias.get(from_, {}).get(import_), + reference_path=reverse_lookup.get((from_, import_)), + ) + while imports.counter.get((from_, import_), 0) > 0: + imports.remove(import_obj) + + def get_imports(parser: OpenAPIParser, model_path: Path) -> Dict[str, object]: imports = Imports() @@ -35,6 +78,7 @@ def get_imports(parser: OpenAPIParser, model_path: Path) -> Dict[str, object]: for operation in parser.operations.values(): if operation.imports: imports.alias.update(operation.imports.alias) + _remove_unused_imports(imports, _collect_used_names(parser)) return {'imports': imports} diff --git a/tests/data/expected/openapi/coverage/callbacks/main.py b/tests/data/expected/openapi/coverage/callbacks/main.py index ab190bb8..f2b7edc3 100644 --- a/tests/data/expected/openapi/coverage/callbacks/main.py +++ b/tests/data/expected/openapi/coverage/callbacks/main.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Optional from fastapi import FastAPI diff --git a/tests/data/expected/openapi/coverage/callbacks_with_operation_id/main.py b/tests/data/expected/openapi/coverage/callbacks_with_operation_id/main.py index e184b855..9365b202 100644 --- a/tests/data/expected/openapi/coverage/callbacks_with_operation_id/main.py +++ b/tests/data/expected/openapi/coverage/callbacks_with_operation_id/main.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Optional from fastapi import FastAPI diff --git a/tests/data/expected/openapi/default_template/upload/main.py b/tests/data/expected/openapi/default_template/upload/main.py index e7e8fc77..cc3073cd 100644 --- a/tests/data/expected/openapi/default_template/upload/main.py +++ b/tests/data/expected/openapi/default_template/upload/main.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Optional, Union +from typing import List, Optional from fastapi import FastAPI, Request, UploadFile @@ -45,3 +45,18 @@ def upload_pet_image_with_octet_stream( Upload image with octet-stream for a pet """ pass + + +@app.post( + '/pets/{id}/images/form-data', + response_model=None, + responses={'default': {'model': Error}}, + tags=['pets'], +) +def upload_pet_images_with_form_data( + id: str, images: List[UploadFile] = ... +) -> Optional[Error]: + """ + Upload images with Form-Data for a pet + """ + pass diff --git a/tests/data/openapi/default_template/upload.yaml b/tests/data/openapi/default_template/upload.yaml index 6e591afe..7acde17d 100644 --- a/tests/data/openapi/default_template/upload.yaml +++ b/tests/data/openapi/default_template/upload.yaml @@ -65,6 +65,39 @@ paths: application/json: schema: $ref: "#/components/schemas/Error" + /pets/{id}/images/form-data: + post: + summary: Upload images with Form-Data for a pet + operationId: uploadPetImagesWithFormData + tags: + - pets + parameters: + - name: id + in: path + required: true + description: The id of the pet + schema: + type: string + requestBody: + content: + multipart/form-data: + schema: + type: object + properties: + images: + type: array + items: + type: string + format: binary + responses: + '201': + description: empty response + default: + description: unexpected error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" components: schemas: Error: diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 00000000..fc5a6868 --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,121 @@ +from pathlib import Path + +import pytest +from datamodel_code_generator.parser.jsonschema import JsonSchemaObject +from datamodel_code_generator.parser.openapi import ReferenceObject, RequestBodyObject + +from fastapi_code_generator.parser import OpenAPIParser + + +def test_get_upload_file_type_resolves_reference(tmp_path: Path) -> None: + schema_path = tmp_path / "schema.yaml" + schema_path.write_text( + """ +openapi: 3.0.0 +info: + title: Test + version: '1.0' +paths: {} +components: + schemas: + Uploads: + type: object + properties: + images: + type: array + items: + type: string + format: binary +""", + encoding="utf-8", + ) + parser = OpenAPIParser(schema_path) + parser.parse_raw() + + file_name, type_hint = parser._get_upload_file_type( + ReferenceObject.model_validate({"$ref": "#/components/schemas/Uploads"}) + ) + + assert file_name == "images" + assert type_hint == "List[UploadFile]" + assert parser.imports_for_fastapi["typing"] == {"List"} + + +@pytest.mark.parametrize("value", [True, None, "string", 123, {}, []]) +def test_is_upload_file_array_rejects_non_schema(value: object) -> None: + parser = OpenAPIParser( + "openapi: 3.0.0\ninfo: {title: Test, version: '1.0'}\npaths: {}\n" + ) + + assert parser._is_upload_file_array(value) is False + + +def test_get_upload_file_name_falls_back_when_properties_are_not_uploads() -> None: + parser = OpenAPIParser( + "openapi: 3.0.0\ninfo: {title: Test, version: '1.0'}\npaths: {}\n" + ) + schema = JsonSchemaObject.model_validate( + { + "type": "object", + "properties": { + "description": { + "type": "string", + } + }, + } + ) + + assert parser._get_upload_file_name(schema) == "file" + + +def test_get_upload_file_type_uses_single_binary_property_name() -> None: + parser = OpenAPIParser( + "openapi: 3.0.0\ninfo: {title: Test, version: '1.0'}\npaths: {}\n" + ) + schema = JsonSchemaObject.model_validate( + { + "type": "object", + "properties": { + "avatar": { + "type": "string", + "format": "binary", + } + }, + } + ) + + assert parser._get_upload_file_type(schema) == ("avatar", "UploadFile") + + +def test_parse_request_body_filters_multipart_from_mixed_content() -> None: + parser = OpenAPIParser( + "openapi: 3.0.0\ninfo: {title: Test, version: '1.0'}\npaths: {}\n" + ) + request_body = RequestBodyObject.model_validate( + { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "file": { + "type": "string", + "format": "binary", + } + }, + } + }, + "application/json": { + "schema": { + "type": "string", + } + }, + } + } + ) + + request_body_fields = parser.parse_request_body( + "MixedUpload", request_body, ["paths", "mixed", "post", "requestBody"] + ) + + assert set(request_body_fields) == {"application/json"}