Source code for core_etl.async_based
# -*- coding: utf-8 -*-
"""Async-based ETL abstract base class."""
import asyncio
from abc import ABC, abstractmethod
from typing import Any
from .base import IBaseETL
[docs]
class IAsyncETL(IBaseETL, ABC):
    """
    Base class for an ETL task that needs to process elements
    in an asynchronous manner.

    ``produce_records`` fills an ``asyncio.Queue`` (``self._queue``) while
    ``max_workers`` consumer tasks pull records from it and hand each one to
    ``_process_record``.

    Note: ``execute()`` calls ``asyncio.run()`` internally and cannot be invoked
    from within a running event loop. In async contexts use:
    ``await asyncio.to_thread(task.execute)``
    """

    # Private stop marker, put on the queue once per worker by ``_stop``.
    # A unique object (rather than ``None``) means a ``None`` record produced
    # by a subclass is processed like any other record, instead of silently
    # terminating a worker and leaving ``queue.join()`` waiting forever.
    _STOP_SIGNAL = object()

    def __init__(
        self,
        max_queue_size: int = 1000,
        max_workers: int = 10,
        **kwargs: Any,
    ) -> None:
        """
        :param max_queue_size: Maximum number of pending records in the queue;
            the producer waits on ``put`` while it is full (``<= 0`` means
            unbounded, per ``asyncio.Queue``).
        :param max_workers: Number of concurrent consumer tasks; must be at
            least 1 (validated in ``_execute``).
        :param kwargs: Forwarded to the parent class initializer.
        """
        super().__init__(**kwargs)
        self.max_queue_size = max_queue_size
        self.max_workers = max_workers
        # The real queue is created inside the running loop by _execute_async.
        self._queue: asyncio.Queue = None  # type: ignore[assignment]
        self._processed_records = 0

    def pre_processing(self) -> None:
        """Reset the processed-records counter (presumably called before each run)."""
        # NOTE(review): does not call super().pre_processing(); confirm that
        # IBaseETL's version (if it defines one) has nothing that must run here.
        self._processed_records = 0

    def _execute(self, *args, **kwargs) -> int:
        """
        Run ``_execute_async`` to completion in a new event loop.

        :returns: Number of records successfully processed.
        :raises ValueError: If ``max_workers`` is less than 1.
        """
        if self.max_workers < 1:
            raise ValueError("max_workers must be at least 1")
        return asyncio.run(
            self._execute_async(*args, **kwargs)
        )

    async def _execute_async(self, *_args, **_kwargs) -> int:
        """
        Start the workers, run the producer, drain the queue and stop the workers.

        If the producer (or the wait on the queue) raises or is cancelled, any
        worker still blocked on ``queue.get()`` is cancelled and awaited before
        the error propagates, so no consumer task is left pending.

        :returns: Number of records successfully processed.
        """
        self._queue = asyncio.Queue(maxsize=self.max_queue_size)
        # Consumers (workers) will process elements from
        # the queue that are generated from producers...
        consumers = [
            asyncio.create_task(self._consume_record(f"WORKER_{i}"))
            for i in range(1, self.max_workers + 1)
        ]
        try:
            await self.produce_records()
            # Returns once every produced record has been marked task_done().
            await self._queue.join()
            await self._stop()
            # After sending stop signal we can wait for consumers...
            await asyncio.gather(*consumers)
        finally:
            # On the success path every consumer has already exited; on a
            # failure path cancel the ones still waiting so they do not leak.
            pending = [task for task in consumers if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)
        self.info(
            "Processed %s elements asynchronously",
            self._processed_records,
        )
        return self._processed_records

    @abstractmethod
    async def produce_records(self) -> None:
        """
        Must be implemented in subclasses. It must populate the queue
        (``self._queue``) with the records that will be processed by the
        consumers via ``_process_record``: call
        ``await self._queue.put(record)`` for each record. Any value,
        including ``None``, may be used as a record.
        """

    async def _consume_record(self, worker_id: str) -> None:
        """
        Pull records from the queue and dispatch them to ``_process_record``.

        Loops until the private stop signal is received. An ``Exception``
        raised while processing a single record is logged and does not stop
        the worker; every dequeued item is marked ``task_done()``.

        :param worker_id: Label used in log messages.
        """
        while True:
            record = await self._queue.get()
            if record is self._STOP_SIGNAL:
                # Keep the queue's unfinished-task count balanced.
                self._queue.task_done()
                self.info("[%s] exiting.", worker_id)
                break
            # Invoking concrete function must consume the record...
            try:
                await self._process_record(record)
                self._processed_records += 1
            except Exception as error:  # pylint: disable=broad-exception-caught
                self.error(
                    "[%s] failed to process record: %s",
                    worker_id, error,
                )
            finally:
                self._queue.task_done()

    @abstractmethod
    async def _process_record(self, record: Any) -> None:
        """Must be implemented by child classes"""

    async def _stop(self) -> None:
        """Send the stop signal to all consumers/workers (one per worker)."""
        for _ in range(self.max_workers):
            await self._queue.put(self._STOP_SIGNAL)