Skip to content

Commit 38b5c9a

Browse files
committed
speed up unit test suite
1 parent 26d3966 commit 38b5c9a

7 files changed

Lines changed: 73 additions & 71 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ jobs:
8989
- name: Run the automated tests
9090
run: |
9191
python --version
92-
uv run --frozen pytest -n auto --dist loadfile -s -v --log-cli-level=INFO tests/unit
92+
uv run --frozen pytest -n auto --dist worksteal -s -v --log-cli-level=INFO tests/unit
9393
9494
e2e-tests:
9595
runs-on: ubuntu-latest

langfuse/_utils/prompt_cache.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import atexit
44
import os
55
from datetime import datetime
6-
from queue import Empty, Queue
6+
from queue import Queue
77
from threading import Thread
88
from typing import Callable, Dict, List, Optional, Set
99

@@ -18,6 +18,7 @@
1818
)
1919

2020
DEFAULT_PROMPT_CACHE_REFRESH_WORKERS = 1
21+
_SHUTDOWN_SENTINEL = object()
2122

2223

2324
class PromptCacheItem:
@@ -46,26 +47,29 @@ def __init__(self, queue: Queue, identifier: int):
4647

4748
def run(self) -> None:
4849
while self.running:
50+
task = self._queue.get()
51+
52+
if task is _SHUTDOWN_SENTINEL:
53+
self._queue.task_done()
54+
continue
55+
56+
logger.debug(
57+
f"PromptCacheRefreshConsumer processing task, {self._identifier}"
58+
)
4959
try:
50-
task = self._queue.get(timeout=1)
51-
logger.debug(
52-
f"PromptCacheRefreshConsumer processing task, {self._identifier}"
60+
task()
61+
# Task failed, but we still consider it processed
62+
except Exception as e:
63+
logger.warning(
64+
f"PromptCacheRefreshConsumer encountered an error, cache was not refreshed: {self._identifier}, {e}"
5365
)
54-
try:
55-
task()
56-
# Task failed, but we still consider it processed
57-
except Exception as e:
58-
logger.warning(
59-
f"PromptCacheRefreshConsumer encountered an error, cache was not refreshed: {self._identifier}, {e}"
60-
)
6166

62-
self._queue.task_done()
63-
except Empty:
64-
pass
67+
self._queue.task_done()
6568

6669
def pause(self) -> None:
6770
"""Pause the consumer."""
6871
self.running = False
72+
self._queue.put(_SHUTDOWN_SENTINEL)
6973

7074

7175
class PromptCacheTaskManager(object):
@@ -99,6 +103,9 @@ def add_task(self, key: str, task: Callable[[], None]) -> None:
99103
def active_tasks(self) -> int:
100104
return len(self._processing_keys)
101105

106+
def wait_for_idle(self) -> None:
107+
self._queue.join()
108+
102109
def _wrap_task(self, key: str, task: Callable[[], None]) -> Callable[[], None]:
103110
def wrapped() -> None:
104111
logger.debug(f"Refreshing prompt cache for key: {key}")

tests/unit/test_otel.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,17 @@ class TestOTelBase:
5454
@pytest.fixture(scope="function", autouse=True)
5555
def cleanup_otel(self):
5656
"""Reset OpenTelemetry state between tests."""
57-
original_provider = trace_api.get_tracer_provider()
57+
from opentelemetry.util._once import Once
58+
59+
trace_api._TRACER_PROVIDER = None
60+
trace_api._PROXY_TRACER_PROVIDER = trace_api.ProxyTracerProvider()
61+
trace_api._TRACER_PROVIDER_SET_ONCE = Once()
62+
5863
yield
59-
trace_api.set_tracer_provider(original_provider)
6064
LangfuseResourceManager.reset()
65+
trace_api._TRACER_PROVIDER = None
66+
trace_api._PROXY_TRACER_PROVIDER = trace_api.ProxyTracerProvider()
67+
trace_api._TRACER_PROVIDER_SET_ONCE = Once()
6168

6269
@pytest.fixture
6370
def memory_exporter(self):
@@ -97,7 +104,7 @@ def mock_init(self, **kwargs):
97104
self,
98105
span_exporter=memory_exporter,
99106
max_export_batch_size=512,
100-
schedule_delay_millis=5000,
107+
schedule_delay_millis=1,
101108
)
102109

103110
monkeypatch.setattr(
@@ -1870,7 +1877,7 @@ def update_random_metadata(thread_id):
18701877
update = random.choice(updates)
18711878

18721879
# Sleep a tiny bit to simulate work and increase chances of thread interleaving
1873-
time.sleep(random.uniform(0.001, 0.01))
1880+
time.sleep(random.uniform(0.0005, 0.001))
18741881

18751882
# Apply the update to current_metadata (in a real system, this would update OTEL span)
18761883
with metadata_lock:
@@ -2001,7 +2008,7 @@ def mock_processor_init(self, **kwargs):
20012008
self,
20022009
span_exporter=exporter,
20032010
max_export_batch_size=512,
2004-
schedule_delay_millis=5000,
2011+
schedule_delay_millis=1,
20052012
)
20062013

20072014
monkeypatch.setattr(
@@ -2118,7 +2125,7 @@ def create_spans_project1():
21182125
metadata={"project": "project1", "index": i},
21192126
)
21202127
# Small sleep to ensure overlap with other thread
2121-
time.sleep(0.01)
2128+
time.sleep(0.001)
21222129
span.end()
21232130

21242131
def create_spans_project2():
@@ -2128,7 +2135,7 @@ def create_spans_project2():
21282135
metadata={"project": "project2", "index": i},
21292136
)
21302137
# Small sleep to ensure overlap with other thread
2131-
time.sleep(0.01)
2138+
time.sleep(0.001)
21322139
span.end()
21332140

21342141
# Start threads
@@ -2378,7 +2385,7 @@ def mock_processor_init(self, **kwargs):
23782385
self,
23792386
span_exporter=exporter,
23802387
max_export_batch_size=512,
2381-
schedule_delay_millis=5000,
2388+
schedule_delay_millis=1,
23822389
)
23832390

23842391
monkeypatch.setattr(
@@ -2757,7 +2764,7 @@ async def async_task(parent_span, task_id):
27572764
child_span = parent_span.start_observation(name=f"async-task-{task_id}")
27582765

27592766
# Simulate async work
2760-
await asyncio.sleep(0.1)
2767+
await asyncio.sleep(0.01)
27612768

27622769
# Update span with results
27632770
child_span.update(
@@ -2948,7 +2955,7 @@ async def test_span_metadata_updates_in_async_context(
29482955

29492956
# Define async tasks that update different parts of metadata
29502957
async def update_temperature():
2951-
await asyncio.sleep(0.1) # Simulate some async work
2958+
await asyncio.sleep(0.01) # Simulate some async work
29522959
main_span.update(
29532960
metadata={
29542961
"llm_config": {
@@ -2960,7 +2967,7 @@ async def update_temperature():
29602967
)
29612968

29622969
async def update_model():
2963-
await asyncio.sleep(0.05) # Simulate some async work
2970+
await asyncio.sleep(0.005) # Simulate some async work
29642971
main_span.update(
29652972
metadata={
29662973
"llm_config": {
@@ -2970,7 +2977,7 @@ async def update_model():
29702977
)
29712978

29722979
async def add_context_length():
2973-
await asyncio.sleep(0.15) # Simulate some async work
2980+
await asyncio.sleep(0.015) # Simulate some async work
29742981
main_span.update(
29752982
metadata={
29762983
"llm_config": {
@@ -2982,7 +2989,7 @@ async def add_context_length():
29822989
)
29832990

29842991
async def update_user_id():
2985-
await asyncio.sleep(0.08) # Simulate some async work
2992+
await asyncio.sleep(0.008) # Simulate some async work
29862993
main_span.update(
29872994
metadata={
29882995
"request_info": {
@@ -3047,7 +3054,7 @@ def test_metrics_and_timing(self, langfuse_client, memory_exporter):
30473054
span = langfuse_client.start_observation(name="timing-test-span")
30483055

30493056
# Add a small delay
3050-
time.sleep(0.1)
3057+
time.sleep(0.01)
30513058

30523059
# End the span
30533060
span.end()
@@ -3089,10 +3096,10 @@ def test_metrics_and_timing(self, langfuse_client, memory_exporter):
30893096
) / 1_000_000_000
30903097
assert span_duration_seconds > 0, "Span duration should be positive"
30913098

3092-
# Since we slept for 0.1 seconds, the span duration should be at least 0.05 seconds
3099+
# Since we slept for 0.01 seconds, the span duration should be at least 0.005 seconds
30933100
# but we'll be generous with the upper bound due to potential system delays
3094-
assert span_duration_seconds >= 0.05, (
3095-
f"Span duration ({span_duration_seconds}s) should be at least 0.05s"
3101+
assert span_duration_seconds >= 0.005, (
3102+
f"Span duration ({span_duration_seconds}s) should be at least 0.005s"
30963103
)
30973104

30983105

@@ -3349,6 +3356,7 @@ def langfuse_client(self, monkeypatch):
33493356
public_key="test-public-key",
33503357
secret_key="test-secret-key",
33513358
base_url="http://test-host",
3359+
tracing_enabled=False,
33523360
)
33533361

33543362
return client

tests/unit/test_prompt.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from time import sleep
21
from unittest.mock import Mock, patch
32

43
import pytest
@@ -139,6 +138,10 @@ def langfuse():
139138
return langfuse_instance
140139

141140

141+
def wait_for_prompt_refresh(langfuse: Langfuse) -> None:
142+
langfuse._resources.prompt_cache._task_manager.wait_for_idle()
143+
144+
142145
def test_get_fresh_prompt(langfuse):
143146
prompt_name = "test_get_fresh_prompt"
144147
prompt = Prompt_Text(
@@ -376,10 +379,7 @@ def test_get_fresh_prompt_when_expired_cache_custom_ttl(mock_time, langfuse: Lan
376379

377380
result_call_3 = langfuse.get_prompt(prompt_name)
378381

379-
while True:
380-
if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0:
381-
break
382-
sleep(0.1)
382+
wait_for_prompt_refresh(langfuse)
383383

384384
assert mock_server_call.call_count == 2
385385
assert result_call_3 == prompt_client
@@ -483,10 +483,7 @@ def test_get_stale_prompt_when_expired_cache_default_ttl(mock_time, langfuse: La
483483
langfuse.get_prompt(prompt_name)
484484
langfuse.get_prompt(prompt_name)
485485

486-
while True:
487-
if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0:
488-
break
489-
sleep(0.1)
486+
wait_for_prompt_refresh(langfuse)
490487

491488
assert mock_server_call.call_count == 2
492489

@@ -527,10 +524,7 @@ def test_get_fresh_prompt_when_expired_cache_default_ttl(mock_time, langfuse: La
527524
mock_time.return_value = DEFAULT_PROMPT_CACHE_TTL_SECONDS + 1
528525

529526
result_call_3 = langfuse.get_prompt(prompt_name)
530-
while True:
531-
if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0:
532-
break
533-
sleep(0.1)
527+
wait_for_prompt_refresh(langfuse)
534528

535529
assert mock_server_call.call_count == 2
536530
assert result_call_3 == prompt_client
@@ -563,10 +557,7 @@ def test_get_expired_prompt_when_failing_fetch(mock_time, langfuse: Langfuse):
563557
mock_server_call.side_effect = Exception("Server error")
564558

565559
result_call_2 = langfuse.get_prompt(prompt_name, max_retries=1)
566-
while True:
567-
if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0:
568-
break
569-
sleep(0.1)
560+
wait_for_prompt_refresh(langfuse)
570561

571562
assert mock_server_call.call_count == 3
572563
assert result_call_2 == prompt_client
@@ -619,10 +610,7 @@ def raise_not_found(*_args: object, **_kwargs: object) -> None:
619610
)
620611
assert stale_result == prompt_client
621612

622-
while True:
623-
if langfuse._resources.prompt_cache._task_manager.active_tasks() == 0:
624-
break
625-
sleep(0.1)
613+
wait_for_prompt_refresh(langfuse)
626614

627615
assert langfuse._resources.prompt_cache.get(cache_key) is None
628616

tests/unit/test_prompt_atexit.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ def test_prompts_atexit():
2020
print("Adding prompt cache", PromptCache)
2121
prompt_cache = PromptCache(max_prompt_refresh_workers=10)
2222
23-
# example task that takes 2 seconds but we will force it to exit earlier
24-
def wait_2_sec():
25-
time.sleep(2)
23+
# example task that stays in flight briefly while the process exits
24+
def wait_briefly():
25+
time.sleep(0.1)
2626
2727
# 8 times
2828
for i in range(8):
29-
prompt_cache.add_refresh_prompt_task(f"key_wait_2_sec_i_{i}", lambda: wait_2_sec())
29+
prompt_cache.add_refresh_prompt_task(
30+
f"key_wait_briefly_i_{i}", lambda: wait_briefly()
31+
)
3032
"""
3133

3234
process = subprocess.Popen(
@@ -74,12 +76,14 @@ async def main():
7476
print("Adding prompt cache", PromptCache)
7577
prompt_cache = PromptCache(max_prompt_refresh_workers=10)
7678
77-
# example task that takes 2 seconds but we will force it to exit earlier
78-
def wait_2_sec():
79-
time.sleep(2)
79+
# example task that stays in flight briefly while the process exits
80+
def wait_briefly():
81+
time.sleep(0.1)
8082
8183
async def add_new_prompt_refresh(i: int):
82-
prompt_cache.add_refresh_prompt_task(f"key_wait_2_sec_i_{i}", lambda: wait_2_sec())
84+
prompt_cache.add_refresh_prompt_task(
85+
f"key_wait_briefly_i_{i}", lambda: wait_briefly()
86+
)
8387
8488
# 8 times
8589
tasks = [add_new_prompt_refresh(i) for i in range(8)]

0 commit comments

Comments
 (0)