Created
September 3, 2024 14:18
-
-
Save vincentwyshan/e069101633254ddf22eebdf7322e97e3 to your computer and use it in GitHub Desktop.
Airflow plugin that supports scheduling a DAG with multiple cron expressions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, Dict, List, Optional | |
| import pendulum | |
| from airflow.exceptions import AirflowTimetableInvalid | |
| from airflow.plugins_manager import AirflowPlugin | |
| from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable | |
| from croniter import croniter | |
| from pendulum import DateTime, timezone | |
class MultiCronTimetable(Timetable):
    """Timetable that fires on the union of several cron expressions.

    Every cron expression in ``cron_defs`` is evaluated in ``timezone``; a run
    boundary is any instant at which at least one of the expressions matches.
    """

    def __init__(
        self,
        cron_defs: List[str],
        timezone: str = "UTC",
    ):
        """Store the configuration and validate it immediately.

        :param cron_defs: cron expressions whose combined schedule is used.
        :param timezone: timezone name the expressions are evaluated in.
        :raises AirflowTimetableInvalid: if ``cron_defs`` is empty or the
            croniter objects cannot be built from the configuration.
        """
        self.cron_defs = cron_defs
        self.timezone = timezone
        self.validate()

    @property
    def summary(self) -> str:
        """A short summary for the timetable.

        This is used to display the timetable in the web UI. The summary is
        the configured cron expressions joined by ``" || "``; the timezone is
        not included.
        """
        return " || ".join(self.cron_defs)

    def infer_manual_data_interval(self, run_after: DateTime) -> DataInterval:
        """Determine the data interval for a manually triggered run.

        The interval is the most recent complete one: it ends at the latest
        cron match before ``run_after`` and starts at the latest cron match
        before that end.
        """
        end = max(cron.get_prev(DateTime) for cron in self.croniter_iters(run_after))
        start = max(cron.get_prev(DateTime) for cron in self.croniter_iters(end))
        return DataInterval(start=start, end=end)

    def next_dagrun_info(
        self, *, last_automated_data_interval: Optional[DataInterval], restriction: TimeRestriction
    ) -> Optional[DagRunInfo]:
        """Determine when the DAG should be scheduled next.

        :param last_automated_data_interval: data interval of the previous
            automated run, or ``None`` if there has not been one.
        :param restriction: earliest/latest bounds and the catchup flag.
        :return: the next run's info, or ``None`` when nothing should be
            scheduled (no earliest bound, or the start is past ``latest``).
        """
        if restriction.earliest is None:
            # No start_date. Don't schedule.
            return None

        if last_automated_data_interval is not None:
            start = last_automated_data_interval.end
        else:
            start = restriction.earliest

        # Make sure the start is in the same timezone as configured.
        start = pendulum.from_timestamp(start.timestamp(), tz=self.timezone)

        # Checked before any cron computation: the outcome depends on ``start`` only.
        if restriction.latest is not None and start > restriction.latest:
            # Over the DAG's scheduled end; don't schedule.
            return None

        end = min(cron.get_next(DateTime) for cron in self.croniter_iters(start))

        if not restriction.catchup:
            # If catchup is disabled, use now as the base time: when the most
            # recent cron match before now is later than the computed end, the
            # interval end is moved forward to it (the start is left unchanged).
            base_time = pendulum.now(tz=self.timezone)
            latest_past = max(cron.get_prev(DateTime) for cron in self.croniter_iters(base_time))
            if latest_past > end:
                end = latest_past

        return DagRunInfo.interval(start=start, end=end)

    def croniter_iters(self, base_datetime: Optional[DateTime] = None) -> List[croniter]:
        """Build one croniter per cron expression, anchored at ``base_datetime``.

        :param base_datetime: anchor instant; it is converted to the configured
            timezone. When ``None``, the current time in that timezone is used.
        :return: croniter objects in the same order as ``cron_defs``.
        """
        if base_datetime is None:
            base_datetime = pendulum.now(timezone(self.timezone))
        else:
            base_datetime = pendulum.from_timestamp(base_datetime.timestamp(), tz=self.timezone)
        return [croniter(expr, start_time=base_datetime) for expr in self.cron_defs]

    def validate(self) -> None:
        """Check the configuration.

        :raises AirflowTimetableInvalid: if no cron expression is configured or
            building the croniter objects fails for any reason.
        """
        if not self.cron_defs:
            raise AirflowTimetableInvalid("At least one cron definition must be present")
        try:
            self.croniter_iters()
        except Exception as e:
            # Deliberately broad: any failure while building the iterators is
            # reported as an invalid timetable. The original error is chained
            # so its traceback is not lost.
            raise AirflowTimetableInvalid(str(e)) from e

    def serialize(self) -> Dict[str, Any]:
        """Serialize the timetable for JSON encoding.

        This is called during DAG serialization to store timetable information
        in the database. This should return a JSON-serializable dict that will
        be fed into ``deserialize`` when the DAG is deserialized.
        """
        return {
            "cron_defs": self.cron_defs,
            "timezone": self.timezone,
        }

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> "MultiCronTimetable":
        """Deserialize a timetable from data.

        This is called when a serialized DAG is deserialized. ``data`` will be
        whatever was returned by ``serialize`` during DAG serialization; its
        keys are passed to the constructor as keyword arguments.
        """
        return cls(**data)

    def get_timezone_offset(self, timezone_name: str) -> str:
        """Return the current UTC offset of ``timezone_name`` as a signed string.

        The offset is expressed in hours, rounded to two decimals, with whole
        hours rendered without a fractional part (e.g. ``"+8"``, ``"+5.5"``,
        ``"-7"``). It is computed for the present moment, so the result for a
        zone with daylight-saving time changes over the year.
        """
        now = pendulum.now(timezone_name)
        hours = round(now.utcoffset().total_seconds() / 3600, 2)
        if hours.is_integer():
            hours = int(hours)
        if hours >= 0:
            return f"+{hours}"
        return str(hours)
class MultiCronTimetablePlugin(AirflowPlugin):
    """Airflow plugin that exposes ``MultiCronTimetable`` through ``timetables``."""

    # Identifier of this plugin.
    name = "multi_cron_timetable_plugin"
    # Timetable classes contributed by this plugin.
    timetables = [MultiCronTimetable]
def test_next_dag_info():
    """The run after a 00:00-12:00 interval covers 12:00 to the next midnight."""
    utc = timezone("UTC")
    timetable = MultiCronTimetable(["0 0 * * *", "0 12 * * *"])

    previous_interval = DataInterval(
        start=DateTime(2021, 1, 1, 0, 0, 0, tzinfo=utc),
        end=DateTime(2021, 1, 1, 12, 0, 0, tzinfo=utc),
    )
    window = TimeRestriction(
        earliest=DateTime(2020, 1, 1, 0, 0, 0, tzinfo=utc),
        latest=DateTime(2023, 1, 2, 0, 0, 0, tzinfo=utc),
        catchup=True,
    )

    info = timetable.next_dagrun_info(
        last_automated_data_interval=previous_interval,
        restriction=window,
    )

    # The new interval starts where the previous one ended and stops at the
    # next match of either expression.
    expected_start = DateTime(2021, 1, 1, 12, 0, 0, tzinfo=utc)
    expected_end = DateTime(2021, 1, 2, 0, 0, 0, tzinfo=utc)
    assert info.data_interval.start == expected_start
    assert info.data_interval.end == expected_end
def test_infer_manual_data_interval():
    """A manual run at 11:00 gets the previous noon-to-midnight interval."""
    utc = timezone("UTC")
    timetable = MultiCronTimetable(["0 0 * * *", "0 12 * * *"])

    triggered_at = DateTime(2021, 1, 2, 11, 0, 0, tzinfo=utc)
    interval = timetable.infer_manual_data_interval(triggered_at)

    expected_start = DateTime(2021, 1, 1, 12, 0, 0, tzinfo=utc)
    expected_end = DateTime(2021, 1, 2, 0, 0, 0, tzinfo=utc)
    assert interval.start == expected_start
    assert interval.end == expected_end
def test_infer_manual_data_interval1():
    """A weekly and a daily expression are combined for a manual run."""
    utc = timezone("UTC")
    timetable = MultiCronTimetable(["30 13 * * 5", "01 12 * * *"])

    # 2024-06-28 is a Friday, so the weekly 13:30 expression matches that day.
    triggered_at = DateTime(2024, 6, 28, 14, 33, 0, tzinfo=utc)
    interval = timetable.infer_manual_data_interval(triggered_at)

    expected_start = DateTime(2024, 6, 28, 12, 1, 0, tzinfo=utc)
    expected_end = DateTime(2024, 6, 28, 13, 30, 0, tzinfo=utc)
    assert interval.start == expected_start
    assert interval.end == expected_end
def test_get_tz_offset():
    """Offsets are rendered as signed hour strings without a trailing ``.0``."""
    timetable = MultiCronTimetable(["1 * * * *"])
    assert timetable.get_timezone_offset("Asia/Shanghai") == "+8"
    assert timetable.get_timezone_offset("UTC") == "+0"
    # get_timezone_offset reads the offset at the current moment, and Los
    # Angeles alternates between UTC-7 (daylight time) and UTC-8 (standard
    # time), so both values are valid depending on when the test runs.
    assert timetable.get_timezone_offset("America/Los_Angeles") in ("-7", "-8")
    assert timetable.get_timezone_offset("Asia/Calcutta") == "+5.5"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment