# pyright: enableExperimentalFeatures=true
# PEP 712 is only active for pyright with the "enableExperimentalFeatures" setting enabled.

import operator
import re
import sys
from typing import (
    TYPE_CHECKING,
    Callable,
    ClassVar,
    Final,
    Generic,
    NoReturn,
    Pattern,
    Protocol,
    Tuple,
    TypeVar,
    Union,
    final,
)

import attrs

if sys.version_info >= (3, 11):
    from typing import ParamSpec, Self, reveal_type
else:
    from typing_extensions import ParamSpec, Self, reveal_type

if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

if sys.version_info >= (3, 9):
    from typing import Annotated
else:
    from typing_extensions import Annotated


# Sentinel meaning "this comparison was not configured"; None stays usable as a real bound.
_MISSING = object()

T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")
P2 = ParamSpec("P2")

# Aliases so the example classes further down read like a generic "model" API.
base_model = attrs.define
adapter = attrs.field


# =====================================================================================================================
# ==== Implementation of Refined.
# =====================================================================================================================

if TYPE_CHECKING:
    # Static checkers treat Refined[T, ...] exactly like Annotated[T, ...].
    Refined = Annotated
else:
    import operator

    # NOTE: these are private typing internals; they are not a stable API.
    from typing import _GenericAlias, _tp_cache, _type_check, _type_repr

    if sys.version_info >= (3, 12):
        from typing import Unpack
    else:
        from typing_extensions import Unpack

    if sys.version_info >= (3, 10):
        from typing import get_origin
    else:
        from typing_extensions import get_origin

    # Almost an exact reimplementation of Annotated.
    @final
    class _RefinedGenericAlias(_GenericAlias, _root=True):
        """Runtime object produced by ``Refined[origin, *refinements]``.

        The refined type is stored in ``__origin__`` and the refinement objects in
        ``__refinements__``.
        """

        if TYPE_CHECKING:
            __origin__: type
            __refinements__: Tuple[object, ...]

        def __init__(self, origin: type, refinements: Tuple[object, ...]):
            """Build the alias, flattening a nested alias into a single level.

            When ``origin`` is itself a ``_RefinedGenericAlias``, its refinements are placed
            before the new ones and its origin is reused.
            """
            if isinstance(origin, _RefinedGenericAlias):
                refinements = origin.__refinements__ + refinements
                origin = origin.__origin__
            super().__init__(origin, origin)
            self.__refinements__ = refinements

        def copy_with(self, params: Tuple[object, ...]):
            """Return a new alias around ``params[0]`` carrying the same refinements.

            Raises:
                AssertionError: if ``params`` does not contain exactly one element.
            """
            if len(params) != 1:
                raise AssertionError
            new_type = params[0]
            return _RefinedGenericAlias(new_type, self.__refinements__)

        def __repr__(self):
            return f"Refined[{_type_repr(self.__origin__)}, {', '.join(repr(r) for r in self.__refinements__)}]"

        def __reduce__(self):
            # Rebuild by subscripting Refined again: Refined[(origin, *refinements)].
            return operator.getitem, (Refined, (self.__origin__, *self.__refinements__))

        def __eq__(self, other: object, /):
            if isinstance(other, type(self)):
                if self.__origin__ != other.__origin__:
                    return False
                return self.__refinements__ == other.__refinements__
            return NotImplemented

        def __hash__(self):
            # Hash the same two fields __eq__ compares. The attribute must be spelled with the
            # trailing double underscore: without it the name is mangled inside the class body
            # and the lookup fails with AttributeError.
            return hash((self.__origin__, self.__refinements__))

    @final
    class Refined:
        """Subscript-only marker: ``Refined[SomeType, refinement, ...]``.

        The class can be neither instantiated nor subclassed; both attempts raise ``TypeError``.
        """

        __slots__ = ()

        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
            raise TypeError("Type Refined cannot be instantiated.")

        def __init_subclass__(cls, *args: object, **kwargs: object) -> NoReturn:
            raise TypeError(f"Cannot subclass {cls.__module__}.Refined")

        def __class_getitem__(
            cls,
            params: Tuple[type, Unpack[Tuple[Union["TypeRefinement", "ValueRefinement"], ...]]],
        ) -> _RefinedGenericAlias:
            """Normalize the subscript to a tuple and delegate to the cached builder."""
            if not isinstance(params, tuple):
                params = (params,)
            # _class_getitem_inner is a plain function looked up on the class, so cls is passed explicitly.
            return cls._class_getitem_inner(cls, *params)

        @_tp_cache(typed=True)
        def _class_getitem_inner(
            cls,
            *params: Unpack[Tuple[type, Unpack[Tuple[Union["TypeRefinement", "ValueRefinement"], ...]]]],
        ) -> _RefinedGenericAlias:
            """Validate the subscript arguments and build the alias.

            Raises:
                TypeError: if fewer than two arguments are given, if the first argument is an
                    unpacked TypeVarTuple, or if ``_type_check`` rejects the first argument.
            """
            if len(params) < 2:
                raise TypeError("Refined[...] should be used with at least two arguments (a type and an annotation).")
            if (not isinstance(params[0], type)) and getattr(params[0], "__typing_is_unpacked_typevartuple__", False):
                raise TypeError("Refined[...] should not be used with an unpacked TypeVarTuple.")

            # ClassVar[...] and Final[...] are accepted as-is; anything else goes through _type_check.
            allowed_special_forms = {ClassVar, Final}
            if get_origin(params[0]) in allowed_special_forms:
                origin = params[0]
            else:
                msg = "Refined[t, ...]: t must be a type."
                origin = _type_check(params[0], msg)
            refinements = tuple(params[1:])
            return _RefinedGenericAlias(origin, refinements)


class TypeRefinement(Protocol):
    """Structural protocol for refinements that judge a type."""

    def __supports_type__(self, t: type) -> bool: ...


class ValueRefinement(Protocol):
    """Structural protocol for refinements that judge a value."""

    def __supports_value__(self, o: object) -> bool: ...


class NumCmp:
    """Value refinement made of optional comparisons against fixed bounds.

    Each keyword (``eq``, ``ne``, ``gt``, ``ge``, ``lt``, ``le``) names the ``operator``
    function of the same name; a keyword left at its default is not checked.
    """

    _op_map: ClassVar = {
        "eq": operator.eq,
        "ne": operator.ne,
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
    }

    def __init__(
        self,
        eq: object = _MISSING,
        ne: object = _MISSING,
        gt: object = _MISSING,
        ge: object = _MISSING,
        lt: object = _MISSING,
        le: object = _MISSING,
    ):
        self.eq = eq
        self.ne = ne
        self.gt = gt
        self.ge = ge
        self.lt = lt
        self.le = le

    def __supports_value__(self, o: object) -> bool:
        """Return True when ``o`` satisfies every configured comparison.

        With no comparison configured the result is True. Evaluation stops at the first
        comparison that does not hold.
        """
        for cmp_name, cmp_op in self._op_map.items():
            bound = getattr(self, cmp_name)
            if bound is not _MISSING and not cmp_op(o, bound):
                return False
        return True


class RePtrn:
    """Value refinement that accepts strings matched by a regular expression."""

    def __init__(self, pattern: Union[str, Pattern[str]]):
        # Check against the concrete re.Pattern class; already-compiled patterns are stored as given.
        self.pattern = pattern if isinstance(pattern, re.Pattern) else re.compile(pattern)

    def __supports_value__(self, o: str) -> bool:
        """Return True when the pattern matches at the start of ``o`` (``re.match`` semantics)."""
        return self.pattern.match(o) is not None


# =====================================================================================================================
# ==== Implementation of parse to superficially match the semantics of Pydantic's thing/use case.
# =====================================================================================================================


class ValidationError(Exception):
    """Raised when a ``Parser`` fails to produce a value."""


@final
class Parser(Generic[P, T]):
    """Wrapper around a callable ``typer`` that converts its arguments into a ``T``.

    Parsers can be chained with ``transform``/``parse`` and combined with ``|`` to try a
    second parser when the first one fails. Subclassing raises ``TypeError``.
    """

    __slots__ = ("typer",)

    def __init__(self, typer: Callable[P, T]):
        self.typer = typer

    def __init_subclass__(cls, *args: object, **kwargs: object) -> NoReturn:
        raise TypeError(f"Cannot subclass {cls.__module__}.Parser")

    def __or__(self, other: "Parser[P2, T]", /) -> "Parser[P2, T]":
        """Return a parser that tries ``self`` first and falls back to ``other``.

        The combined parser raises ``ValidationError`` when both alternatives fail.
        Returns ``NotImplemented`` when ``other`` is not a ``Parser``.
        """
        if not isinstance(other, Parser):  # pyright: ignore [reportUnnecessaryIsInstance]
            return NotImplemented

        def temp(*args: P2.args, **kwargs: P2.kwargs) -> T:
            last_error = None
            for parser in (self, other):
                try:
                    # Go through Parser.__call__ rather than the raw typer: it turns any failure
                    # (e.g. the ValueError raised by int("abc")) into ValidationError, which is
                    # what lets the next alternative run.
                    return parser(*args, **kwargs)
                except ValidationError as exc:
                    last_error = exc
                    print(f"Failed to parse {(args, kwargs)} with {parser.typer}. Attempting next.")  # noqa: T201
            raise ValidationError from last_error

        return Parser(temp)

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Run the wrapped callable.

        Raises:
            ValidationError: chained from whatever exception the wrapped callable raised.
        """
        try:
            return self.typer(*args, **kwargs)
        except Exception as exc:  # noqa: BLE001
            raise ValidationError from exc

    def _convert(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Return a new parser that applies ``fn`` to this parser's raw result."""

        def temp(*args: P.args, **kwargs: P.kwargs) -> U:
            return fn(self.typer(*args, **kwargs))

        return Parser(temp)

    def transform(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Chain ``fn`` after this parser (same implementation as ``parse``)."""
        return self._convert(fn)

    def parse(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Chain ``fn`` after this parser (same implementation as ``transform``)."""
        return self._convert(fn)

    def ge(self, floor: int) -> Self:
        # XXX: Nonfunctional placeholder.
        return self

    def lt(self, ceil: int) -> Self:
        # XXX: Nonfunctional placeholder.
        return self


def parse(tp: Callable[P, T]) -> Parser[P, T]:
    """Wrap ``tp`` in a ``Parser``; entry point for building parser chains."""
    return Parser(tp)


# =====================================================================================================================
# ==== Attempt at an example using the above.
# =====================================================================================================================

# Pretend these classes are subclasses of pydantic.BaseModel instead of fresh classes being wrapped by class decorators.


# This is what Pydantic wants their transformers and validators to look like.
@base_model
class Before:
    username: Annotated[str, parse(str).transform(str.lower)]
    birthday: Annotated[int, (parse(int) | parse(str).transform(str.strip).parse(int)).ge(0).lt(512)]
    age: Annotated[int, parse(int)]


# This is an attrs class with PEP 712 active, and imo looks like a better alternative.
@base_model
class After:
    username: str = adapter(converter=parse(str).transform(str.lower))
    birthday: Refined[int, NumCmp(ge=0, lt=512)] = adapter(
        converter=(parse(int) | parse(str).transform(str.strip).parse(int)),
    )
    age: int = adapter(converter=parse(int))


def test() -> None:
    """Build an ``After`` instance and show the revealed type and value of each field."""
    reveal_type(After.__init__)  # Type of "After.__init__" is "(self: After, username: object, birthday: object, age: str | Buffer | SupportsInt | SupportsIndex | SupportsTrunc) -> None"

    ex = After(10, "1010", 1.0)

    # reveal_type hands its argument back at runtime, so each field is revealed and then printed.
    print(reveal_type(ex.username))  # Type of "ex.username" is "str"
    print(reveal_type(ex.birthday))  # Type of "ex.birthday" is "int"
    print(reveal_type(ex.age))  # Type of "ex.age" is "int"


if __name__ == "__main__":
    test()