Source code for gfw.common.beam.pipeline.dag.linear

"""This module contains a linear DAG implementation for Apache Beam pipelines."""

import logging

from functools import cached_property
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import apache_beam as beam

from apache_beam import PTransform
from apache_beam.pvalue import PCollection

from gfw.common.beam.transforms import SampleAndLogElements

from .base import Dag


[docs] class LinearDag(Dag): """A linear DAG implementation for Apache Beam pipelines. This DAG: 1. Applies multiple sources PTransforms and merges outputs into a single PCollection. 2. Applies a single core PTransform, with an optional side inputs PTransform. 3. Applies multiple sinks PTransforms. Args: sources: A list of PTransforms that read input data. core: The core PTransform that processes the data. side_inputs: A PTransform used to read side inputs that will be injected into the core transform. sinks: A list of PTransforms that write the output data. Attributes: output_paths: A list of output paths for each sink, if they contain the path attribute. """ def __init__( self, sources: Tuple[PTransform[Any, Any], ...] = (), core: Optional[PTransform[Any, Any]] = None, side_inputs: Optional[PTransform[Any, Any]] = None, sinks: Tuple[PTransform[Any, Any], ...] = (), ) -> None: """Initializes the :class:`LinearDag` object.""" self._sources = sources self._core = core or beam.Map(lambda x: x) self._sinks = sinks self._side_inputs = side_inputs
[docs] def apply(self, p: beam.Pipeline) -> PCollection: """Applies the linear DAG implementation to an Apache Beam pipeline.""" if self._side_inputs is not None: side_inputs = p | self._side_inputs self._core.set_side_inputs(side_inputs) # Source transformations inputs = [p | transform for transform in self._sources] if len(inputs) > 1: inputs = inputs | "JoinSources" >> beam.Flatten() else: inputs = inputs[0] # Core transformation outputs = inputs | self._core # Sink transformations for transform in self._sinks: outputs | transform if logging.getLogger().level == logging.DEBUG: inputs | "Log Inputs" >> SampleAndLogElements(message="Input: {e}", sample_size=1) outputs | "Log Outputs" >> SampleAndLogElements(message="Output: {e}", sample_size=1) return outputs
@cached_property def output_paths(self) -> List[Union[Path, str]]: """Resolves and returns a list of output paths for each sink in the pipeline.""" paths = [] for sink in self._sinks: paths.append(getattr(sink, "path", "Unknown.")) return paths