Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions fastapi_code_generator/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def get_parameter_type(
if isinstance(content.schema_, ReferenceObject):
data_type = self.get_ref_data_type(content.schema_.ref)
ref_model = self.get_ref_model(content.schema_.ref)
schema = JsonSchemaObject.parse_obj(ref_model) # pragma: no cover
schema = JsonSchemaObject.model_validate(ref_model) # pragma: no cover
else:
schema = content.schema_
break
Expand Down Expand Up @@ -494,7 +494,23 @@ def parse_request_body(
request_body: RequestBodyObject,
path: List[str],
) -> Dict[str, DataType]:
request_body_fields = super().parse_request_body(name, request_body, path)
if 'multipart/form-data' in request_body.content:
content = {
media_type: media_obj
for media_type, media_obj in request_body.content.items()
if media_type != 'multipart/form-data'
}
request_body_fields = (
super().parse_request_body(
name,
request_body.model_copy(update={'content': content}),
path,
)
if content
else {}
)
else:
request_body_fields = super().parse_request_body(name, request_body, path)
arguments: List[Argument] = []
for (
media_type,
Expand Down Expand Up @@ -545,10 +561,11 @@ def parse_request_body(
Import.from_full_path("fastapi.Request")
)
elif media_type == 'multipart/form-data':
file_name, type_hint = self._get_upload_file_type(media_obj.schema_)
arguments.append(
Argument(
name='file', # type: ignore
type_hint='UploadFile', # type: ignore
name=file_name, # type: ignore
type_hint=type_hint, # type: ignore
required=True,
)
)
Expand All @@ -558,6 +575,48 @@ def parse_request_body(
self._temporary_operation['_request'] = arguments[0] if arguments else None
return request_body_fields

def _get_upload_file_type(
self, schema: Union[JsonSchemaObject, ReferenceObject]
) -> tuple[str, str]:
if isinstance(schema, ReferenceObject):
schema = JsonSchemaObject.model_validate(self.get_ref_model(schema.ref))
file_name = self._get_upload_file_name(schema)
if self._is_upload_file_array(schema):
self.imports_for_fastapi.append(Import(from_='typing', import_='List'))
return file_name, 'List[UploadFile]'
return file_name, 'UploadFile'

def _get_upload_file_name(self, schema: JsonSchemaObject) -> str:
if schema.properties:
# The operation template supports one multipart upload argument.
for property_name, property_schema in schema.properties.items():
if self._is_upload_file_array(
property_schema
) or self._is_upload_file_schema(property_schema):
return stringcase.snakecase(
self.model_resolver.get_valid_field_name(property_name)
)
return 'file'
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def _is_upload_file_array(self, schema: Any) -> bool:
if not isinstance(schema, JsonSchemaObject):
return False
if schema.is_array and self._is_upload_file_schema(schema.items):
return True
if schema.properties:
return any(
self._is_upload_file_array(property_schema)
for property_schema in schema.properties.values()
)
return False

def _is_upload_file_schema(self, schema: Any) -> bool:
return (
isinstance(schema, JsonSchemaObject)
and schema.type == 'string'
and schema.format == 'binary'
)

def parse_responses( # type: ignore[override]
self,
name: str,
Expand Down
44 changes: 44 additions & 0 deletions fastapi_code_generator/visitors/imports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from pathlib import Path
from typing import Dict, Optional

Expand All @@ -8,6 +9,8 @@
from fastapi_code_generator.parser import OpenAPIParser
from fastapi_code_generator.visitor import Visitor

IDENTIFIER_PATTERN = re.compile(r'\b[A-Za-z_][A-Za-z0-9_]*\b')


def _get_most_of_reference(data_type: DataType) -> Optional[Reference]:
if data_type.reference:
Expand All @@ -19,6 +22,46 @@ def _get_most_of_reference(data_type: DataType) -> Optional[Reference]:
return None


def _collect_used_names(parser: OpenAPIParser) -> set[str]:
names: set[str] = set()
pending_operations = list(parser.operations.values())
while pending_operations:
operation = pending_operations.pop()
names.update(IDENTIFIER_PATTERN.findall(operation.arguments))
names.update(IDENTIFIER_PATTERN.findall(operation.return_type))
names.update(IDENTIFIER_PATTERN.findall(operation.response))
for models in operation.additional_responses.values():
for model in models.values():
names.update(IDENTIFIER_PATTERN.findall(model))
for callback_operations in operation.callbacks.values():
pending_operations.extend(callback_operations)
return names
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _remove_unused_imports(imports: Imports, used_names: set[str]) -> None:
unused = [
(from_, import_)
for from_, imports_ in imports.items()
for import_ in imports_
if not {imports.get_effective_name(from_, import_), import_}.intersection(
used_names
)
]
reverse_lookup = {
(import_.from_, import_.import_): reference_path
for reference_path, import_ in imports.reference_paths.items()
}
for from_, import_ in unused:
import_obj = Import(
from_=from_,
import_=import_,
alias=imports.alias.get(from_, {}).get(import_),
reference_path=reverse_lookup.get((from_, import_)),
)
while imports.counter.get((from_, import_), 0) > 0:
imports.remove(import_obj)


def get_imports(parser: OpenAPIParser, model_path: Path) -> Dict[str, object]:
imports = Imports()

Expand All @@ -35,6 +78,7 @@ def get_imports(parser: OpenAPIParser, model_path: Path) -> Dict[str, object]:
for operation in parser.operations.values():
if operation.imports:
imports.alias.update(operation.imports.alias)
_remove_unused_imports(imports, _collect_used_names(parser))
return {'imports': imports}


Expand Down
2 changes: 1 addition & 1 deletion tests/data/expected/openapi/coverage/callbacks/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Optional, Union
from typing import Optional

from fastapi import FastAPI

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Optional, Union
from typing import Optional

from fastapi import FastAPI

Expand Down
17 changes: 16 additions & 1 deletion tests/data/expected/openapi/default_template/upload/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Optional, Union
from typing import List, Optional

from fastapi import FastAPI, Request, UploadFile

Expand Down Expand Up @@ -45,3 +45,18 @@ def upload_pet_image_with_octet_stream(
Upload image with octet-stream for a pet
"""
pass


@app.post(
'/pets/{id}/images/form-data',
response_model=None,
responses={'default': {'model': Error}},
tags=['pets'],
)
def upload_pet_images_with_form_data(
id: str, images: List[UploadFile] = ...
) -> Optional[Error]:
"""
Upload images with Form-Data for a pet
"""
pass
33 changes: 33 additions & 0 deletions tests/data/openapi/default_template/upload.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,39 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/Error"
/pets/{id}/images/form-data:
post:
summary: Upload images with Form-Data for a pet
operationId: uploadPetImagesWithFormData
tags:
- pets
parameters:
- name: id
in: path
required: true
description: The id of the pet
schema:
type: string
requestBody:
content:
multipart/form-data:
schema:
type: object
properties:
images:
type: array
items:
type: string
format: binary
responses:
'201':
description: empty response
default:
description: unexpected error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
components:
schemas:
Error:
Expand Down
121 changes: 121 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from pathlib import Path

import pytest
from datamodel_code_generator.parser.jsonschema import JsonSchemaObject
from datamodel_code_generator.parser.openapi import ReferenceObject, RequestBodyObject

from fastapi_code_generator.parser import OpenAPIParser


def test_get_upload_file_type_resolves_reference(tmp_path: Path) -> None:
schema_path = tmp_path / "schema.yaml"
schema_path.write_text(
"""
openapi: 3.0.0
info:
title: Test
version: '1.0'
paths: {}
components:
schemas:
Uploads:
type: object
properties:
images:
type: array
items:
type: string
format: binary
""",
encoding="utf-8",
)
parser = OpenAPIParser(schema_path)
parser.parse_raw()

file_name, type_hint = parser._get_upload_file_type(
ReferenceObject.model_validate({"$ref": "#/components/schemas/Uploads"})
)

assert file_name == "images"
assert type_hint == "List[UploadFile]"
assert parser.imports_for_fastapi["typing"] == {"List"}


@pytest.mark.parametrize("value", [True, None, "string", 123, {}, []])
def test_is_upload_file_array_rejects_non_schema(value: object) -> None:
parser = OpenAPIParser(
"openapi: 3.0.0\ninfo: {title: Test, version: '1.0'}\npaths: {}\n"
)

assert parser._is_upload_file_array(value) is False


def test_get_upload_file_name_falls_back_when_properties_are_not_uploads() -> None:
parser = OpenAPIParser(
"openapi: 3.0.0\ninfo: {title: Test, version: '1.0'}\npaths: {}\n"
)
schema = JsonSchemaObject.model_validate(
{
"type": "object",
"properties": {
"description": {
"type": "string",
}
},
}
)

assert parser._get_upload_file_name(schema) == "file"


def test_get_upload_file_type_uses_single_binary_property_name() -> None:
parser = OpenAPIParser(
"openapi: 3.0.0\ninfo: {title: Test, version: '1.0'}\npaths: {}\n"
)
schema = JsonSchemaObject.model_validate(
{
"type": "object",
"properties": {
"avatar": {
"type": "string",
"format": "binary",
}
},
}
)

assert parser._get_upload_file_type(schema) == ("avatar", "UploadFile")


def test_parse_request_body_filters_multipart_from_mixed_content() -> None:
parser = OpenAPIParser(
"openapi: 3.0.0\ninfo: {title: Test, version: '1.0'}\npaths: {}\n"
)
request_body = RequestBodyObject.model_validate(
{
"content": {
"multipart/form-data": {
"schema": {
"type": "object",
"properties": {
"file": {
"type": "string",
"format": "binary",
}
},
}
},
"application/json": {
"schema": {
"type": "string",
}
},
}
}
)

request_body_fields = parser.parse_request_body(
"MixedUpload", request_body, ["paths", "mixed", "post", "requestBody"]
)

assert set(request_body_fields) == {"application/json"}
Loading