145 lines
4.3 KiB
Python
145 lines
4.3 KiB
Python
# Names exported by ``from <this module> import *``
__all__ = ["fields_set", "is_set", "set_fields", "unset_fields", "with_fields_set"]
|
|
from dataclasses import ( # type: ignore
|
|
Field,
|
|
_FIELD,
|
|
_FIELDS,
|
|
_FIELD_INITVAR,
|
|
is_dataclass,
|
|
)
|
|
from functools import wraps
|
|
from inspect import signature
|
|
from typing import AbstractSet, Any, Collection, Set, Type, TypeVar, cast
|
|
|
|
from apischema.objects.fields import get_field_name
|
|
from apischema.utils import PREFIX
|
|
|
|
# Key under which each tracked instance stores, in its ``__dict__``, the set of
# field names considered "set"
FIELDS_SET_ATTR = f"{PREFIX}fields_set"
# Marker attribute placed on the wrappers installed by ``with_fields_set`` so that
# an already wrapped method is not wrapped a second time
_ALREADY_SET = f"{PREFIX}already_set"

# Type variable bound to classes: ``with_fields_set`` returns the class it received
Cls = TypeVar("Cls", bound=Type)

# Registry of every class that went through ``with_fields_set``
_fields_set_classes: Set[type] = set()
|
|
|
|
|
|
def support_fields_set(cls: type) -> bool:
    """Tell whether ``cls`` or one of the classes of its MRO is registered.

    Returns:
        ``True`` as soon as one class of ``cls.__mro__`` is found in
        ``_fields_set_classes``, ``False`` otherwise.
    """
    for ancestor in cls.__mro__:
        if ancestor in _fields_set_classes:
            return True
    return False
|
|
|
|
|
|
def with_fields_set(cls: Cls) -> Cls:
    """Class decorator making instances of ``cls`` track which fields are set.

    ``__new__``, ``__init__`` and ``__setattr__`` of ``cls`` are replaced by
    wrappers maintaining a set of names stored in the instance ``__dict__``
    under ``FIELDS_SET_ATTR``:

    * ``__new__`` creates an empty set;
    * ``__init__`` ends with the set containing what was in it before the call,
      the parameters received positionally or by keyword (``InitVar``
      pseudo-fields excluded) and, for dataclasses, the fields having
      ``init=False`` or a truthy ``DEFAULT_AS_SET_METADATA`` metadata;
    * ``__setattr__`` adds the assigned name once the assignment succeeded.

    A method already carrying the ``_ALREADY_SET`` marker is left untouched.
    The class is finally added to ``_fields_set_classes``.

    Args:
        cls: the class to instrument.

    Returns:
        The same class object, modified in place.

    Raises:
        RuntimeError: at instantiation/assignment time (not at decoration time),
            when the wrapped ``__init__`` fails with the ``object.__init__``
            arguments error, or when ``__setattr__`` is called on an instance
            whose ``__dict__`` has no ``FIELDS_SET_ATTR`` entry.
    """
    from apischema.metadata.keys import DEFAULT_AS_SET_METADATA

    # Names of InitVar pseudo-fields: passed to __init__ but not real fields
    init_fields = set()
    # Names of fields always considered set once __init__ has run
    post_init_fields = set()
    if is_dataclass(cls):
        for field in getattr(cls, _FIELDS).values():
            assert isinstance(field, Field)
            if field._field_type == _FIELD_INITVAR:  # type: ignore
                init_fields.add(field.name)
            if field._field_type == _FIELD and not field.init:  # type: ignore
                post_init_fields.add(field.name)
            if field.metadata.get(DEFAULT_AS_SET_METADATA):
                post_init_fields.add(field.name)
    # Parameter names of __init__ without the leading ``self``, used to map
    # positional arguments to names
    params = list(signature(cls.__init__).parameters)[1:]
    old_new = cls.__new__
    old_init = cls.__init__
    old_setattr = cls.__setattr__

    def new_new(*args, **kwargs):
        if old_new is object.__new__:
            # object.__new__ must only receive the class when __init__ takes arguments
            obj = object.__new__(args[0])
        else:
            obj = old_new(*args, **kwargs)
        # Initialize FIELDS_SET_ATTR in order to prevent inherited class which override
        # __init__ to raise in __setattr__
        obj.__dict__[FIELDS_SET_ATTR] = set()
        return obj

    def new_init(self, *args, **kwargs):
        prev_fields_set = self.__dict__.get(FIELDS_SET_ATTR, set()).copy()
        # Fresh set so that assignments made by the wrapped __init__ can be recorded
        # by new_setattr; it is replaced below once initialization is done
        self.__dict__[FIELDS_SET_ATTR] = set()
        try:
            old_init(self, *args, **kwargs)
        except TypeError as err:
            if str(err) == no_dataclass_init_error:
                raise RuntimeError(dataclass_before_error) from None
            else:
                raise
        arg_fields = {*params[: len(args)], *kwargs} - init_fields
        self.__dict__[FIELDS_SET_ATTR] = prev_fields_set | arg_fields | post_init_fields

    def new_setattr(self, attr, value):
        # Look the tracking set up first, so a missing set is reported before the
        # attribute is touched
        try:
            tracked = self.__dict__[FIELDS_SET_ATTR]
        except KeyError:
            raise RuntimeError(dataclass_before_error) from None
        old_setattr(self, attr, value)
        # Record the name only after the assignment succeeded: a rejected
        # assignment (e.g. on a frozen dataclass) must not mark the field as set
        tracked.add(attr)

    for attr, old, new in [
        ("__new__", old_new, new_new),
        ("__init__", old_init, new_init),
        ("__setattr__", old_setattr, new_setattr),
    ]:
        # Skip methods already wrapped, e.g. inherited from a decorated base class
        if hasattr(old, _ALREADY_SET):
            continue
        setattr(new, _ALREADY_SET, True)
        setattr(cls, attr, wraps(old)(new))  # type: ignore

    _fields_set_classes.add(cls)
    return cls
|
|
|
|
|
|
# Text compared with the message of a ``TypeError`` raised by the wrapped
# ``__init__``; on a match the error is turned into a ``RuntimeError``.
# NOTE(review): presumably the message CPython's ``object.__init__`` emits when it
# is given arguments -- confirm against the supported Python versions
no_dataclass_init_error = (
    "object.__init__() takes exactly one argument (the instance to initialize)"
)
# Message of the ``RuntimeError`` raised by the wrappers of ``with_fields_set``
dataclass_before_error = (
    f"{with_fields_set.__name__} must be put before dataclass decorator"
)


# Generic type variable letting the helpers below return the type they were given
T = TypeVar("T")
|
|
|
|
|
|
def _field_names(fields: Collection) -> AbstractSet[str]:
    """Return the set of names given by ``get_field_name`` for each item of ``fields``."""
    return {get_field_name(item) for item in fields}
|
|
|
|
|
|
def _fields_set(obj: Any) -> Set[str]:
    """Return the mutable set of field names stored on ``obj``.

    The set is read from the ``FIELDS_SET_ATTR`` attribute of ``obj`` and returned
    as is (not copied), so callers mutating it change the state of ``obj``.

    Raises:
        TypeError: if ``obj`` has no ``FIELDS_SET_ATTR`` attribute.
    """
    try:
        return getattr(obj, FIELDS_SET_ATTR)
    except AttributeError:
        # ``from None``: the AttributeError only exposes the internal attribute
        # name; the other error translations of this module hide it the same way
        raise TypeError(
            f"Type {obj.__class__} is not decorated"
            f" with {with_fields_set.__name__}"
        ) from None
|
|
|
|
|
|
def set_fields(obj: T, *fields: Any, overwrite=False) -> T:
    """Mark ``fields`` as set on ``obj`` and return ``obj``.

    Args:
        obj: instance whose tracking set (see ``_fields_set``) is updated in place.
        *fields: values passed one by one to ``get_field_name`` to obtain names.
        overwrite: when true, the tracking set is emptied before adding names.

    Raises:
        TypeError: raised by ``_fields_set`` when ``obj`` carries no tracking set.
    """
    tracked = _fields_set(obj)
    if overwrite:
        tracked.clear()
    for item in fields:
        tracked.add(get_field_name(item))
    return obj
|
|
|
|
|
|
def unset_fields(obj: T, *fields: Any) -> T:
    """Remove ``fields`` from the names marked as set on ``obj`` and return ``obj``.

    Names are obtained with ``get_field_name``; names that are not currently in
    the tracking set are ignored.

    Raises:
        TypeError: raised by ``_fields_set`` when ``obj`` carries no tracking set.
    """
    tracked = _fields_set(obj)
    for item in fields:
        tracked.discard(get_field_name(item))
    return obj
|
|
|
|
|
|
# A plain alias of ``_fields_set`` with a narrower annotation would do the job,
# but IDEs handle a real function better
def fields_set(obj: Any) -> AbstractSet[str]:
    """Return the names marked as set on ``obj``, typed as a read-only set.

    Raises:
        TypeError: raised by ``_fields_set`` when ``obj`` carries no tracking set.
    """
    tracked: AbstractSet[str] = _fields_set(obj)
    return tracked
|
|
|
|
|
|
class FieldIsSet:
    """Proxy whose attribute access answers "is this name in the fields set?".

    Every attribute lookup on an instance, whatever the name, returns a boolean
    telling whether that name belongs to the set captured at construction.
    """

    def __init__(self, obj: Any):
        """Capture the set returned by ``fields_set(obj)``."""
        self.fields_set = fields_set(obj)

    def __getattribute__(self, name: str) -> bool:
        """Return whether ``name`` is in the captured set."""
        # Go through ``object`` to reach the real attribute, as normal lookup on
        # ``self`` would re-enter this method
        captured = object.__getattribute__(self, "fields_set")
        return name in captured
|
|
|
|
|
|
def is_set(obj: T) -> T:
    """Return a ``FieldIsSet`` proxy of ``obj``, typed as ``obj`` itself.

    The ``cast`` only serves static typing; at runtime the returned object is
    the ``FieldIsSet`` instance.
    """
    proxy = FieldIsSet(obj)
    return cast(T, proxy)
|