Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish_extract_worker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
id-token: write
contents: write
env:
PYTHON_VERSION: 3.12
PYTHON_VERSION: 3.14
ASTRAL_VERSION: 0.11.6
steps:
- uses: actions/checkout@v6
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_extract_worker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ jobs:
test:
runs-on: ubuntu-latest
env:
PYTHON_VERSION: 3.12
ASTRAL_VERSION: 0.11.6
PYTHON_VERSION: 3.14
ASTRAL_VERSION: 0.11.24
steps:
- uses: actions/checkout@v6
- name: Setup Python project
Expand Down
44 changes: 26 additions & 18 deletions datashare-python/datashare_python/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import asyncio
import shutil
from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator, Generator, Iterator, Sequence
from collections.abc import AsyncGenerator, Generator, Sequence
from pathlib import Path

import aiohttp
Expand All @@ -10,7 +8,9 @@
from elasticsearch._async.helpers import async_streaming_bulk
from icij_common.es import DOC_ROOT_ID, ES_DOCUMENT_TYPE, ID, ESClient
from icij_common.test_utils import reset_env # noqa: F401
from pytest_asyncio import is_async_test
from temporalio import workflow
from temporalio.service import RPCError, RPCStatusCode

from datashare_python.config import (
DatashareClientConfig,
Expand Down Expand Up @@ -59,6 +59,13 @@
}


def pytest_collection_modifyitems(items: list) -> None:
pytest_asyncio_tests = (item for item in items if is_async_test(item))
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
for async_test in pytest_asyncio_tests:
async_test.add_marker(session_scope_marker, append=False)


@activity_defn(name="mocked-act")
def mocked_act() -> None:
pass
Expand All @@ -81,16 +88,6 @@ def test_deps() -> list[ContextManagerFactory]:
return [set_es_client, set_task_client]


@pytest.fixture(scope="session")
def event_loop(
request: pytest.FixtureRequest, # noqa: ARG001
) -> Iterator[asyncio.AbstractEventLoop]:
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()


@pytest.fixture(scope="session")
def test_worker_config() -> WorkerConfig:
logging_config = LoggingConfig(
Expand All @@ -117,9 +114,7 @@ def test_worker_config_path(test_worker_config: WorkerConfig, tmpdir: Path) -> P

@pytest.fixture(scope="session")
async def worker_lifetime_deps(
event_loop: AbstractEventLoop,
test_deps: list[ContextManagerFactory],
test_worker_config: WorkerConfig,
test_deps: list[ContextManagerFactory], test_worker_config: WorkerConfig
) -> AsyncGenerator[None, None]:
worker_id = "test-worker-id"
ctx = "test application"
Expand All @@ -128,7 +123,6 @@ async def worker_lifetime_deps(
ctx=ctx,
worker_id=worker_id,
worker_config=test_worker_config,
event_loop=event_loop,
):
yield

Expand Down Expand Up @@ -174,11 +168,25 @@ async def test_task_client(
@pytest.fixture(scope="session")
async def test_temporal_client_session(
test_worker_config: WorkerConfig,
event_loop: AbstractEventLoop, # noqa: ARG001
) -> TemporalClient: # noqa: ANN001
return await test_worker_config.to_temporal_client()


@pytest.fixture
async def test_temporal_client(
test_temporal_client_session: TemporalClient,
) -> TemporalClient: # noqa: ANN001
client = test_temporal_client_session
async for wf in client.list_workflows():
try:
await client.get_workflow_handle(wf.id).terminate()
except RPCError as e:
if e.status != RPCStatusCode.NOT_FOUND:
raise

return client


@pytest.fixture
async def populate_es(
test_es_client: ESClient,
Expand Down
8 changes: 5 additions & 3 deletions datashare-python/datashare_python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Any, ParamSpec, TypeVar
from uuid import uuid4

import nest_asyncio
import temporalio
from pydantic import ValidationError
from temporalio import activity, workflow
Expand Down Expand Up @@ -76,9 +75,12 @@ def to_progress(self) -> Progress:


class ActivityWithProgress:
def __init__(self, temporal_client: Client, event_loop: asyncio.AbstractEventLoop):
def __init__(
self,
temporal_client: Client,
event_loop: asyncio.AbstractEventLoop | None = None,
):
self._temporal_client = temporal_client
nest_asyncio.apply()
self._event_loop = event_loop


Expand Down
5 changes: 4 additions & 1 deletion datashare-python/datashare_python/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import logging
import os
Expand Down Expand Up @@ -157,7 +158,7 @@ async def worker_context(
workflows: list[type] | None = None,
worker_config: WorkerConfig,
client: TemporalClient,
event_loop: AbstractEventLoop,
event_loop: AbstractEventLoop | None = None,
task_queue: str,
dependencies: list[ContextManagerFactory] | None = None,
sandboxed: bool = True,
Expand All @@ -169,6 +170,8 @@ async def worker_context(
discovered.extend(workflows)
if dependencies is not None:
discovered.extend(dependencies)
if event_loop is None:
event_loop = asyncio.get_event_loop()
discovered.append(worker_config)
loggers = copy(worker_config.logging.loggers)
discovered_loggers = {_get_object_package(o).__name__ for o in discovered}
Expand Down
2 changes: 1 addition & 1 deletion datashare-python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ dependencies = [
"hatchling~=1.27",
"icij-common[elasticsearch]~=0.8.2",
"langcodes~=3.5",
"nest-asyncio~=1.6",
"orjson~=3.11",
"python-json-logger~=4.0",
"pyyaml~=6.0",
Expand Down Expand Up @@ -76,6 +75,7 @@ dev = [

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_debug = true
asyncio_default_fixture_loop_scope = "session"
markers = [
"integration",
Expand Down
22 changes: 4 additions & 18 deletions worker-template/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import asyncio
import uuid
from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator
from typing import Any

Expand All @@ -19,13 +17,14 @@
doc_1,
doc_2,
doc_3,
event_loop,
populate_es,
pytest_collection_modifyitems,
test_deps,
test_es_client,
test_es_client_session,
test_task_client,
test_task_client_session,
test_temporal_client,
test_temporal_client_session,
text_0,
text_1,
Expand Down Expand Up @@ -65,18 +64,13 @@ def test_worker_config() -> TranslateAndClassifyWorkerConfig:

@pytest.fixture(scope="session")
async def lifetime_deps(
event_loop: AbstractEventLoop, # noqa: F811
test_deps: list[ContextManagerFactory], # noqa: F811
test_worker_config: WorkerConfig,
) -> AsyncGenerator[None, Any]:
ctx = "unit test application"
worker_id = f"test-worker-{uuid.uuid4()}"
async with with_dependencies(
test_deps,
worker_config=test_worker_config,
event_loop=event_loop,
worker_id=worker_id,
ctx=ctx,
test_deps, worker_config=test_worker_config, worker_id=worker_id, ctx=ctx
):
yield

Expand All @@ -85,7 +79,6 @@ async def lifetime_deps(
async def workflows_worker(
test_worker_config: WorkerConfig, # noqa: F811
test_temporal_client_session: TemporalClient, # noqa: F811
event_loop: asyncio.AbstractEventLoop, # noqa: F811
test_deps: list[ContextManagerFactory], # noqa: F811
) -> AsyncGenerator[None, None]:
client = test_temporal_client_session
Expand All @@ -97,7 +90,6 @@ async def workflows_worker(
workflows=workflows,
worker_config=test_worker_config,
client=client,
event_loop=event_loop,
task_queue=task_queue,
dependencies=test_deps,
)
Expand All @@ -109,12 +101,11 @@ async def workflows_worker(
async def io_worker(
test_worker_config: WorkerConfig, # noqa: F811
test_temporal_client_session: TemporalClient, # noqa: F811
event_loop: asyncio.AbstractEventLoop, # noqa: F811
test_deps: list[ContextManagerFactory], # noqa: F811
) -> AsyncGenerator[None, None]:
client = test_temporal_client_session
worker_id = f"test-io-worker-{uuid.uuid4()}"
pong_activity = Pong(temporal_client=client, event_loop=event_loop)
pong_activity = Pong(temporal_client=client)
io_activities = [
pong_activity.pong,
CreateTranslationBatches.create_translation_batches,
Expand All @@ -126,7 +117,6 @@ async def io_worker(
activities=io_activities,
worker_config=test_worker_config,
client=client,
event_loop=event_loop,
task_queue=task_queue,
dependencies=test_deps,
)
Expand All @@ -138,7 +128,6 @@ async def io_worker(
async def translation_worker(
test_worker_config: WorkerConfig, # noqa: F811
test_temporal_client_session: TemporalClient, # noqa: F811
event_loop: asyncio.AbstractEventLoop, # noqa: F811
test_deps: list[ContextManagerFactory], # noqa: F811
) -> AsyncGenerator[None, None]:
client = test_temporal_client_session
Expand All @@ -150,7 +139,6 @@ async def translation_worker(
activities=translation_activities,
worker_config=test_worker_config,
client=client,
event_loop=event_loop,
task_queue=task_queue,
dependencies=test_deps,
)
Expand All @@ -162,7 +150,6 @@ async def translation_worker(
async def classification_worker(
test_worker_config: WorkerConfig,
test_temporal_client_session: TemporalClient, # noqa: F811
event_loop: asyncio.AbstractEventLoop, # noqa: F811
test_deps: list[ContextManagerFactory], # noqa: F811
) -> AsyncGenerator[None, None]:
client = test_temporal_client_session
Expand All @@ -174,7 +161,6 @@ async def classification_worker(
activities=classification_activities,
worker_config=test_worker_config,
client=client,
event_loop=event_loop,
task_queue=task_queue,
dependencies=test_deps,
)
Expand Down
6 changes: 3 additions & 3 deletions worker-template/tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
async def test_ping_workflow_e2e(
io_worker: Worker, # noqa: ARG001
workflows_worker: Worker, # noqa: ARG001
test_temporal_client_session: TemporalClient,
test_temporal_client: TemporalClient,
) -> None:
# Given
temporal_client = test_temporal_client_session
client = test_temporal_client
wf_id = f"ping-{uuid.uuid4()}"

# When
args = dict()
response = await temporal_client.execute_workflow(
response = await client.execute_workflow(
PingWorkflow,
args,
id=wf_id,
Expand Down
2 changes: 1 addition & 1 deletion workers/asr-worker/asr_worker/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def postprocess_act(
)
logger.debug("wrote transcription for %s", t_path)
if progress is not None and event_loop is not None:
event_loop.run_until_complete(progress(i))
asyncio.run_coroutine_threadsafe(progress(i), event_loop).result()
return n_docs


Expand Down
1 change: 1 addition & 0 deletions workers/asr-worker/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ exclude = [

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_debug = true
asyncio_default_fixture_loop_scope = "session"
markers = [
"integration",
Expand Down
3 changes: 2 additions & 1 deletion workers/asr-worker/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
TEST_PROJECT,
clear_dirs,
doc_3,
event_loop,
index_docs,
populate_es,
pytest_collection_modifyitems,
test_es_client,
test_es_client_session,
test_temporal_client,
test_temporal_client_session,
test_worker_config,
typer_asyncio_patch,
Expand Down
Loading
Loading