Source code for core_etl.async_based

# -*- coding: utf-8 -*-

"""Async-based ETL abstract base class."""

import asyncio
from abc import ABC, abstractmethod
from typing import Any

from .base import IBaseETL


class IAsyncETL(IBaseETL, ABC):
    """
    Base class for an ETL task that needs to process elements in an
    asynchronous manner.

    ``produce_records`` fills a bounded ``asyncio.Queue`` while
    ``max_workers`` consumer tasks drain it and hand every record to
    ``_process_record``.

    Note:
        ``execute()`` calls ``asyncio.run()`` internally and cannot be
        invoked from within a running event loop. In async contexts use:
        ``await asyncio.to_thread(task.execute)``
    """

    def __init__(
        self, max_queue_size: int = 1000, max_workers: int = 10, **kwargs
    ) -> None:
        """
        Store the queue/worker settings and initialise the run state.

        Args:
            max_queue_size: ``maxsize`` given to the ``asyncio.Queue`` that
                is created for each run.
            max_workers: Number of consumer tasks started per run. It is
                validated in ``_execute`` and must be at least 1.
            **kwargs: Forwarded unchanged to the parent initializer.
        """
        super().__init__(**kwargs)
        self.max_queue_size = max_queue_size
        self.max_workers = max_workers
        # The queue is (re)created at the start of every ``_execute_async``
        # call, i.e. inside the running event loop; until then it is unset.
        self._queue: asyncio.Queue = None  # type: ignore[assignment]
        self._processed_records = 0

    def pre_processing(self) -> None:
        """Reset the processed-records counter."""
        self._processed_records = 0

    def _execute(self, *args, **kwargs) -> int:
        """
        Run the asynchronous pipeline to completion on a new event loop.

        Returns:
            The number of records processed successfully.

        Raises:
            ValueError: If ``max_workers`` is lower than 1 (with no worker
                the queue would never be drained).
        """
        if self.max_workers < 1:
            raise ValueError("max_workers must be at least 1")
        return asyncio.run(self._execute_async(*args, **kwargs))

    async def _execute_async(self, *_args, **_kwargs) -> int:
        """
        Start the consumers, run the producer and wait for the queue to drain.

        Positional and keyword arguments are accepted but ignored.

        Returns:
            The number of records processed successfully.
        """
        self._queue = asyncio.Queue(maxsize=self.max_queue_size)

        # Consumers (workers) will process elements from
        # the queue that are generated from producers...
        consumers = [
            asyncio.create_task(self._consume_record(f"WORKER_{i}"))
            for i in range(1, self.max_workers + 1)
        ]

        try:
            await self.produce_records()
            await self._queue.join()
            await self._stop()

            # After sending stop signal we can wait for consumers...
            await asyncio.gather(*consumers)
        finally:
            # If anything above failed (or this coroutine was cancelled) the
            # workers are still blocked on ``queue.get()``; cancel and await
            # them so no task outlives this call. On the success path every
            # consumer has already finished and this is a no-op.
            pending = [task for task in consumers if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

        self.info(
            "Processed %s elements asynchronously",
            self._processed_records,
        )
        return self._processed_records

    @abstractmethod
    async def produce_records(self) -> None:
        """
        Must be implemented in subclasses.

        It must populate the ``asyncio.Queue`` (``self._queue``) with the
        records that will be processed by the consumers in the
        ``_process_record`` function.

        Ensure: ``await self._queue.put(record)`` is called for every
        record. ``None`` is reserved as the stop signal for the workers and
        must not be enqueued as a record.
        """

    async def _consume_record(self, worker_id: str) -> None:
        """
        Pull records from the queue and dispatch them to ``_process_record``.

        The loop ends when the ``None`` stop signal is received. A failure
        while processing one record is logged and does not stop the worker.

        Args:
            worker_id: Label used to identify this worker in log messages.
        """
        while True:
            record = await self._queue.get()

            if record is None:
                # Acknowledge the stop signal as well, so that every
                # ``get()`` is matched by a ``task_done()`` and the queue's
                # unfinished-task counter stays balanced.
                self._queue.task_done()
                self.info("[%s] exiting.", worker_id)
                break

            # Invoking concrete function must consume the record...
            try:
                await self._process_record(record)
                self._processed_records += 1
            except Exception as error:  # pylint: disable=broad-exception-caught
                self.error(
                    "[%s] failed to process record: %s",
                    worker_id,
                    error,
                )
            finally:
                # Always acknowledge the record, even on failure, otherwise
                # ``queue.join()`` would never return.
                self._queue.task_done()

    @abstractmethod
    async def _process_record(self, record: Any) -> None:
        """Must be implemented by child classes"""

    async def _stop(self) -> None:
        """Send stop signal to all consumers/workers."""
        # One ``None`` per worker: each worker exits on the first one it gets.
        for _ in range(self.max_workers):
            await self._queue.put(None)