Skip to content

Commit cd286d9

Browse files
committed
Fix specify-tags router filtering
1 parent e9f2965 commit cd286d9

5 files changed

Lines changed: 88 additions & 42 deletions

File tree

docs/cli-reference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ Input schema: `openapi/using_routers/using_routers_example.yaml`
119119

120120
### --specify-tags
121121

122-
Regenerate only the routers matching a comma-separated tag list.
122+
Generate or regenerate only the routers matching a comma-separated tag list.
123123

124124
`fastapi-codegen --input openapi/using_routers/using_routers_example.yaml --output app --template-dir modular_template --generate-routers --specify-tags Wild Boars, Fat Cats`
125125

docs/llms-full.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ Input schema: `openapi/using_routers/using_routers_example.yaml`
445445

446446
### --specify-tags
447447

448-
Regenerate only the routers matching a comma-separated tag list.
448+
Generate or regenerate only the routers matching a comma-separated tag list.
449449

450450
`fastapi-codegen --input openapi/using_routers/using_routers_example.yaml --output app --template-dir modular_template --generate-routers --specify-tags Wild Boars, Fat Cats`
451451

fastapi_code_generator/cli.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,28 @@ def generate_code(
279279
routers = sorted(
280280
[re.sub(TITLE_PATTERN, '_', tag.strip()).lower() for tag in sorted_tags]
281281
)
282-
template_vars = {**template_vars, "routers": routers, "tags": sorted_tags}
282+
router_tag_pairs = list(zip(routers, sorted_tags))
283+
specified_tags = set()
284+
existing_main_has_router_includes = False
285+
if generate_routers and specify_tags:
286+
specified_tags = {tag.strip() for tag in str(specify_tags).split(",")}
287+
main_path = output_dir / "main.py"
288+
if main_path.exists():
289+
existing_main_has_router_includes = (
290+
"app.include_router" in main_path.read_text(encoding=encoding)
291+
)
292+
293+
main_router_tag_pairs = router_tag_pairs
294+
if specified_tags and not existing_main_has_router_includes:
295+
main_router_tag_pairs = [
296+
(router, tag) for router, tag in router_tag_pairs if tag in specified_tags
297+
]
298+
299+
template_vars = {
300+
**template_vars,
301+
"routers": [router for router, _ in main_router_tag_pairs],
302+
"tags": [tag for _, tag in main_router_tag_pairs],
303+
}
283304

284305
for target in template_dir.rglob("*"):
285306
relative_path = target.relative_to(template_dir)
@@ -290,25 +311,20 @@ def generate_code(
290311
)
291312

292313
if generate_routers:
293-
tags = sorted_tags
294314
results.pop(Path("routers.jinja2"), None)
295-
if specify_tags:
296-
if Path(output_dir.joinpath("main.py")).exists():
297-
with open(Path(output_dir.joinpath("main.py")), 'r') as file:
298-
content = file.read()
299-
if "app.include_router" in content:
300-
tags = sorted(
301-
set(tag.strip() for tag in str(specify_tags).split(","))
302-
)
315+
router_pairs = router_tag_pairs
316+
if specified_tags and not existing_main_has_router_includes:
317+
router_pairs = main_router_tag_pairs
303318

304319
for target in template_dir.rglob("routers.*"):
305320
relative_path = target.relative_to(template_dir)
306-
for router, tag in zip(routers, sorted_tags):
321+
for router, tag in router_pairs:
307322
if (
308323
not Path(output_dir.joinpath("routers", router))
309324
.with_suffix(".py")
310325
.exists()
311-
or tag in tags
326+
or not specified_tags
327+
or tag in specified_tags
312328
):
313329
template_vars["tag"] = tag.strip()
314330
template = environment.get_template(str(relative_path))

fastapi_code_generator/prompt_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@
302302
},
303303
{
304304
'options': ['--specify-tags'],
305-
'description': 'Regenerate only the routers matching a '
305+
'description': 'Generate or regenerate only the routers matching a '
306306
'comma-separated tag list.',
307307
'cli_args': [
308308
'--input',

tests/main/test_main.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def test_generate_router_preserves_path_parameter_name(output_dir: Path) -> None
571571

572572
@pytest.mark.cli_doc(
573573
options=["--specify-tags"],
574-
option_description="Regenerate only the routers matching a comma-separated tag list.",
574+
option_description="Generate or regenerate only the routers matching a comma-separated tag list.",
575575
cli_args=[
576576
"--input",
577577
"openapi/using_routers/using_routers_example.yaml",
@@ -619,21 +619,37 @@ def test_generate_modify_specific_routers(oas_file: Path, output_dir: Path) -> N
619619

620620
@freeze_time("2023-04-11")
621621
def test_generate_specific_tags_without_existing_main(output_dir: Path) -> None:
622-
run_cli_and_assert(
623-
input_path=DATA_PATH
624-
/ OPEN_API_USING_ROUTERS_DIR_NAME
625-
/ "using_routers_example.yaml",
626-
output_path=output_dir,
627-
expected_path=EXPECTED_OPENAPI_PATH / "using_routers" / "using_routers_example",
628-
extra_args=[
629-
"--template-dir",
630-
str(BUILTIN_MODULAR_TEMPLATE_DIR),
631-
"--generate-routers",
632-
"--specify-tags",
633-
SPECIFIC_TAGS,
634-
],
622+
assert (
623+
run_main_with_args(
624+
[
625+
"--input",
626+
str(
627+
DATA_PATH
628+
/ OPEN_API_USING_ROUTERS_DIR_NAME
629+
/ "using_routers_example.yaml"
630+
),
631+
"--output",
632+
str(output_dir),
633+
"--template-dir",
634+
str(BUILTIN_MODULAR_TEMPLATE_DIR),
635+
"--generate-routers",
636+
"--specify-tags",
637+
SPECIFIC_TAGS,
638+
]
639+
)
640+
== 0
635641
)
636642

643+
main_text = output_dir.joinpath("main.py").read_text(encoding="utf-8")
644+
assert "from .routers import fat_cats, wild_boars" in main_text
645+
assert "app.include_router(fat_cats.router)" in main_text
646+
assert "app.include_router(wild_boars.router)" in main_text
647+
assert "slim_dogs" not in main_text
648+
assert output_dir.joinpath("routers", "fat_cats.py").exists()
649+
assert output_dir.joinpath("routers", "wild_boars.py").exists()
650+
assert not output_dir.joinpath("routers", "slim_dogs.py").exists()
651+
validate_generated_code(output_dir)
652+
637653

638654
@freeze_time("2023-04-11")
639655
def test_generate_specific_tags_with_existing_main_without_router_includes(
@@ -644,21 +660,35 @@ def test_generate_specific_tags_with_existing_main_without_router_includes(
644660
"from fastapi import FastAPI\n\napp = FastAPI()\n",
645661
encoding="utf-8",
646662
)
647-
run_cli_and_assert(
648-
input_path=DATA_PATH
649-
/ OPEN_API_USING_ROUTERS_DIR_NAME
650-
/ "using_routers_example.yaml",
651-
output_path=output_dir,
652-
expected_path=EXPECTED_OPENAPI_PATH / "using_routers" / "using_routers_example",
653-
extra_args=[
654-
"--template-dir",
655-
str(BUILTIN_MODULAR_TEMPLATE_DIR),
656-
"--generate-routers",
657-
"--specify-tags",
658-
SPECIFIC_TAGS,
659-
],
663+
assert (
664+
run_main_with_args(
665+
[
666+
"--input",
667+
str(
668+
DATA_PATH
669+
/ OPEN_API_USING_ROUTERS_DIR_NAME
670+
/ "using_routers_example.yaml"
671+
),
672+
"--output",
673+
str(output_dir),
674+
"--template-dir",
675+
str(BUILTIN_MODULAR_TEMPLATE_DIR),
676+
"--generate-routers",
677+
"--specify-tags",
678+
SPECIFIC_TAGS,
679+
]
680+
)
681+
== 0
660682
)
661683

684+
main_text = output_dir.joinpath("main.py").read_text(encoding="utf-8")
685+
assert "from .routers import fat_cats, wild_boars" in main_text
686+
assert "slim_dogs" not in main_text
687+
assert output_dir.joinpath("routers", "fat_cats.py").exists()
688+
assert output_dir.joinpath("routers", "wild_boars.py").exists()
689+
assert not output_dir.joinpath("routers", "slim_dogs.py").exists()
690+
validate_generated_code(output_dir)
691+
662692

663693
@freeze_time("2020-06-19")
664694
def test_generate_non_200_responses(output_dir: Path) -> None:

0 commit comments

Comments
 (0)