Skip to content

Instantly share code, notes, and snippets.

@vincentwyshan
Created September 3, 2024 14:18
Show Gist options
  • Select an option

  • Save vincentwyshan/e069101633254ddf22eebdf7322e97e3 to your computer and use it in GitHub Desktop.

Select an option

Save vincentwyshan/e069101633254ddf22eebdf7322e97e3 to your computer and use it in GitHub Desktop.
Airflow plugin that lets a DAG be scheduled by multiple cron expressions
from typing import Any, Dict, List, Optional
import pendulum
from airflow.exceptions import AirflowTimetableInvalid
from airflow.plugins_manager import AirflowPlugin
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from croniter import croniter
from pendulum import DateTime, timezone
class MultiCronTimetable(Timetable):
    """Timetable that schedules a DAG on the union of several cron expressions.

    Every expression in ``cron_defs`` is evaluated in ``timezone``. A tick of
    the combined schedule is any time at which at least one expression fires.
    See ``next_dagrun_info`` for how data intervals are derived from ticks.
    """

    def __init__(
        self,
        cron_defs: List[str],
        timezone: str = "UTC",
    ):
        """
        :param cron_defs: Cron expressions; a run is due when any of them fires.
        :param timezone: Timezone name (as accepted by pendulum) in which the
            expressions are evaluated.
        :raises AirflowTimetableInvalid: if the configuration is unusable.
        """
        self.cron_defs = cron_defs
        self.timezone = timezone
        self.validate()

    @property
    def summary(self) -> str:
        """A short summary for the timetable, displayed in the web UI.

        All expressions are joined with ``" || "`` so the UI shows every
        schedule the DAG follows.
        """
        return " || ".join(self.cron_defs)

    def infer_manual_data_interval(self, run_after: DateTime) -> DataInterval:
        """Determine the data interval for a manually triggered run.

        The interval ends at the latest tick that croniter's ``get_prev``
        reports before ``run_after`` across all expressions. It starts at the
        latest tick reported before that end.
        """
        end = max(cron.get_prev(DateTime) for cron in self.croniter_iters(run_after))
        start = max(cron.get_prev(DateTime) for cron in self.croniter_iters(end))
        return DataInterval(start=start, end=end)

    def next_dagrun_info(
        self,
        *,
        last_automated_data_interval: Optional[DataInterval],
        restriction: TimeRestriction,
    ) -> Optional[DagRunInfo]:
        """Determine when the DAG should next be scheduled.

        - Without a ``start_date`` (``restriction.earliest is None``) nothing
          is scheduled.
        - The interval starts at the previous automated interval's end, or at
          ``restriction.earliest`` for the first automated run.
        - It ends at the earliest tick of any expression after that start.
        - With catchup disabled, the end is moved forward to the latest tick
          before "now" when that is later.
        - ``None`` is returned once the start is past ``restriction.latest``.
        """
        if restriction.earliest is None:
            # No start_date. Don't schedule.
            return None

        if last_automated_data_interval is not None:
            start = last_automated_data_interval.end
        else:
            start = restriction.earliest
        # Make sure the start is in the configured timezone so cron ticks are
        # computed in that zone's wall-clock time.
        start = pendulum.from_timestamp(start.timestamp(), tz=self.timezone)

        end = min(cron.get_next(DateTime) for cron in self.croniter_iters(start))

        if not restriction.catchup:
            # If catchup is disabled, use now as the base time to determine the
            # next run end.
            # NOTE(review): only the end is moved; the start stays where it
            # was, so the interval spans every skipped tick. Confirm intended.
            now = pendulum.now(tz=self.timezone)
            latest_tick = max(cron.get_prev(DateTime) for cron in self.croniter_iters(now))
            if latest_tick > end:
                end = latest_tick

        if restriction.latest is not None and start > restriction.latest:
            # Over the DAG's scheduled end; don't schedule.
            return None
        return DagRunInfo.interval(start=start, end=end)

    def croniter_iters(self, base_datetime: Optional[DateTime] = None) -> List[croniter]:
        """Build one croniter per expression, all anchored at ``base_datetime``.

        :param base_datetime: Anchor time, converted into the configured
            timezone. Defaults to the current time in that timezone.
        :return: One ``croniter`` per entry of ``cron_defs``, in the same order.
        """
        if base_datetime is None:
            base_datetime = pendulum.now(timezone(self.timezone))
        else:
            base_datetime = pendulum.from_timestamp(base_datetime.timestamp(), tz=self.timezone)
        return [croniter(expr, start_time=base_datetime) for expr in self.cron_defs]

    def validate(self) -> None:
        """Check the configuration.

        :raises AirflowTimetableInvalid: if ``cron_defs`` is empty, is a bare
            string, or contains an expression (or uses a timezone name) that
            cannot be parsed.
        """
        if not self.cron_defs:
            raise AirflowTimetableInvalid("At least one cron definition must be present")
        if isinstance(self.cron_defs, str):
            # A bare string would be iterated character by character below and
            # surface as a confusing croniter error.
            raise AirflowTimetableInvalid(
                "cron_defs must be a list of cron expressions, not a single string"
            )
        try:
            # Building the iterators resolves the timezone and parses every
            # cron expression, so invalid input fails here.
            self.croniter_iters()
        except Exception as e:
            raise AirflowTimetableInvalid(str(e)) from e

    def serialize(self) -> Dict[str, Any]:
        """Serialize the timetable for JSON encoding.

        This is called during DAG serialization to store timetable information
        in the database. The returned dict is fed back into ``deserialize``
        when the DAG is deserialized.
        """
        return dict(
            cron_defs=self.cron_defs,
            timezone=self.timezone,
        )

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> "MultiCronTimetable":
        """Deserialize a timetable from data.

        ``data`` is whatever ``serialize`` returned during DAG serialization.
        Its keys match the constructor parameters.
        """
        return cls(**data)

    def get_timezone_offset(self, timezone_name: str) -> str:
        """Return the *current* UTC offset of ``timezone_name`` as signed hours.

        Whole-hour offsets drop the fraction (``"+8"``, ``"+0"``, ``"-7"``).
        Other offsets keep up to two decimals (``"+5.5"``). Because "now" is
        used, the result changes across DST transitions.
        """
        now = pendulum.now(timezone_name)
        offset = round(now.utcoffset().total_seconds() / 3600, 2)
        if offset.is_integer():
            offset = int(offset)
        if offset >= 0:
            return f"+{offset}"
        return str(offset)
class MultiCronTimetablePlugin(AirflowPlugin):
    """Airflow plugin that registers :class:`MultiCronTimetable`."""

    # Timetable classes contributed by this plugin.
    timetables = [MultiCronTimetable]
    # Identifier under which Airflow's plugin manager registers this plugin.
    name = "multi_cron_timetable_plugin"
def test_next_dag_info():
    """After a 00:00-12:00 interval, the next run covers 12:00 to midnight."""
    utc = timezone("UTC")
    timetable = MultiCronTimetable(["0 0 * * *", "0 12 * * *"])

    previous_interval = DataInterval(
        start=DateTime(2021, 1, 1, 0, 0, 0, tzinfo=utc),
        end=DateTime(2021, 1, 1, 12, 0, 0, tzinfo=utc),
    )
    window = TimeRestriction(
        earliest=DateTime(2020, 1, 1, 0, 0, 0, tzinfo=utc),
        latest=DateTime(2023, 1, 2, 0, 0, 0, tzinfo=utc),
        catchup=True,
    )

    info = timetable.next_dagrun_info(
        last_automated_data_interval=previous_interval,
        restriction=window,
    )

    expected_start = DateTime(2021, 1, 1, 12, 0, 0, tzinfo=utc)
    expected_end = DateTime(2021, 1, 2, 0, 0, 0, tzinfo=utc)
    assert info.data_interval.start == expected_start
    assert info.data_interval.end == expected_end
def test_infer_manual_data_interval():
    """A manual run at 11:00 on Jan 2 gets the interval [Jan 1 12:00, Jan 2 00:00]."""
    utc = timezone("UTC")
    timetable = MultiCronTimetable(["0 0 * * *", "0 12 * * *"])

    interval = timetable.infer_manual_data_interval(
        DateTime(2021, 1, 2, 11, 0, 0, tzinfo=utc)
    )

    assert (interval.start, interval.end) == (
        DateTime(2021, 1, 1, 12, 0, 0, tzinfo=utc),
        DateTime(2021, 1, 2, 0, 0, 0, tzinfo=utc),
    )
def test_infer_manual_data_interval1():
    """A weekly and a daily expression combine; each bound takes the latest tick."""
    utc = timezone("UTC")
    timetable = MultiCronTimetable(["30 13 * * 5", "01 12 * * *"])
    friday_afternoon = DateTime(2024, 6, 28, 14, 33, 0, tzinfo=utc)  # 2024-06-28 is a Friday

    interval = timetable.infer_manual_data_interval(friday_afternoon)

    # End: Friday's 13:30 weekly tick; start: the 12:01 daily tick just before it.
    assert interval.start == DateTime(2024, 6, 28, 12, 1, 0, tzinfo=utc)
    assert interval.end == DateTime(2024, 6, 28, 13, 30, 0, tzinfo=utc)
def test_get_tz_offset():
    """Offsets are rendered as signed hours, without a trailing ".0"."""
    timetable = MultiCronTimetable(["1 * * * *"])
    assert timetable.get_timezone_offset("Asia/Shanghai") == "+8"
    assert timetable.get_timezone_offset("UTC") == "+0"
    # Los Angeles observes DST: the current offset is -8 (PST) or -7 (PDT)
    # depending on the date the test runs, so pinning "-7" failed in winter.
    assert timetable.get_timezone_offset("America/Los_Angeles") in ("-8", "-7")
    assert timetable.get_timezone_offset("Asia/Calcutta") == "+5.5"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment