92 lines
		
	
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			92 lines
		
	
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import warnings
 | |
| from dataclasses import dataclass
 | |
| from types import new_class
 | |
| from typing import Callable, Optional, TYPE_CHECKING, Tuple, Type, Union
 | |
| 
 | |
| from apischema.conversions import Conversion
 | |
| from apischema.conversions.conversions import ResolvedConversion
 | |
| from apischema.dataclasses import replace
 | |
| from apischema.utils import PREFIX, identity
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from apischema.deserialization.coercion import Coerce
 | |
| 
 | |
# A model is given either directly as a class or as a zero-argument callable
# that returns the class (DataclassModel.dataclass calls it when it is not a type).
Model = Union[Type, Callable[[], Type]]
 | |
| 
 | |
| 
 | |
def check_model(origin: Type, model: Type):
    """Validate that ``model`` is usable as a model of ``origin``.

    Args:
        origin: The type the model stands for.
        model: The candidate model; must be a class.

    Raises:
        TypeError: If ``model`` is not a class, or if the ``__parameters__``
            of ``origin`` and ``model`` differ (a missing attribute is
            treated as an empty tuple).
    """
    if not isinstance(model, type):
        raise TypeError("Dataclass model must be a dataclass")

    origin_params = getattr(origin, "__parameters__", ())
    model_params = getattr(model, "__parameters__", ())
    if origin_params != model_params:
        raise TypeError("Dataclass model must have the same generic parameters")
 | |
| 
 | |
| 
 | |
# Attribute name under which a generated model class records its origin type;
# has_model_origin / get_model_origin look this attribute up.
MODEL_ORIGIN_ATTR = f"{PREFIX}model_origin"

# Name of the instance attribute in which DataclassModel caches the class it
# generates.
DATACLASS_ATTR = "_dataclass"
 | |
| 
 | |
| 
 | |
@dataclass(frozen=True)
class DataclassModel:
    """Placeholder for a model class of ``origin`` that is built on first use.

    The generated class is exposed through the ``dataclass`` property and
    cached on the instance after the first access.
    """

    # Type the model stands for; instantiating the generated class calls it.
    origin: Type
    # Either the model class itself or a zero-argument callable returning it.
    model: Model
    # When true, the generated class is NOT tagged with MODEL_ORIGIN_ATTR.
    fields_only: bool

    @property
    def dataclass(self) -> Type:
        """Return the generated model class, creating it on first access.

        The generated class subclasses the resolved model, keeps the model's
        ``__name__``, and overrides ``__new__`` so that instantiating it
        forwards all arguments to ``origin``. Unless ``fields_only`` is set,
        it also carries ``origin`` under ``MODEL_ORIGIN_ATTR``.

        Raises:
            TypeError: If the resolved model fails ``check_model``.
        """
        if hasattr(self, DATACLASS_ATTR):
            return getattr(self, DATACLASS_ATTR)

        origin = self.origin
        # A non-type model is a factory: call it to obtain the actual class.
        model = self.model if isinstance(self.model, type) else self.model()
        # Validate with a plain call rather than inside an ``assert``, so the
        # check is not stripped when Python runs with -O.
        check_model(origin, model)

        namespace = {"__new__": lambda _, *args, **kwargs: origin(*args, **kwargs)}
        if not self.fields_only:
            namespace[MODEL_ORIGIN_ATTR] = origin
        cls = new_class(
            model.__name__, (model,), exec_body=lambda ns: ns.update(namespace)
        )
        # The dataclass is frozen, so go through object.__setattr__ to cache.
        object.__setattr__(self, DATACLASS_ATTR, cls)
        return cls
 | |
| 
 | |
| 
 | |
def dataclass_model(
    origin: Type,
    model: Model,
    *,
    fields_only: bool = False,
    additional_properties: Optional[bool] = None,
    coercion: Optional["Coerce"] = None,
    fall_back_on_default: Optional[bool] = None,
    exclude_unset: Optional[bool] = None,
) -> Tuple[Conversion, Conversion]:
    """Build the pair of identity conversions between ``origin`` and its model.

    Deprecated: emits a ``DeprecationWarning`` on every call.

    Args:
        origin: The type being modelled.
        model: The model class, or a zero-argument callable returning it.
            A class is validated immediately with ``check_model``; a callable
            is only resolved later by ``DataclassModel.dataclass``.
        fields_only: Forwarded to ``DataclassModel``.
        additional_properties: Accepted but not used by this function.
        coercion: Accepted but not used by this function.
        fall_back_on_default: Accepted but not used by this function.
        exclude_unset: Accepted but not used by this function.

    Returns:
        ``(model -> origin, origin -> model)`` conversions, both using
        ``identity`` as converter.

    Raises:
        TypeError: If ``model`` is a class that fails ``check_model``.
    """
    warnings.warn(
        "dataclass_model is deprecated, use set_object_fields instead",
        DeprecationWarning,
        # Attribute the warning to the caller's line, not to this module.
        stacklevel=2,
    )
    if isinstance(model, type):
        check_model(origin, model)

    model_type = DataclassModel(origin, model, fields_only)
    return Conversion(identity, source=model_type, target=origin), Conversion(
        identity, source=origin, target=model_type
    )
 | |
| 
 | |
| 
 | |
def has_model_origin(cls: Type) -> bool:
    """Tell whether ``cls`` exposes the ``MODEL_ORIGIN_ATTR`` attribute."""
    tagged: bool = hasattr(cls, MODEL_ORIGIN_ATTR)
    return tagged
 | |
| 
 | |
| 
 | |
def get_model_origin(cls: Type) -> Type:
    """Return the value stored on ``cls`` under ``MODEL_ORIGIN_ATTR``.

    Raises:
        AttributeError: If ``cls`` has no such attribute.
    """
    origin = getattr(cls, MODEL_ORIGIN_ATTR)
    return origin
 | |
| 
 | |
| 
 | |
def handle_dataclass_model(conversion: ResolvedConversion) -> ResolvedConversion:
    """Swap ``DataclassModel`` endpoints of a conversion for their classes.

    ``source`` is examined first, then ``target``; each one that is a
    ``DataclassModel`` is replaced by its ``dataclass`` property value. The
    result is wrapped in ``ResolvedConversion`` again.
    """
    resolved: Conversion = conversion
    for side in ("source", "target"):
        endpoint = getattr(resolved, side)
        if isinstance(endpoint, DataclassModel):
            resolved = replace(resolved, **{side: endpoint.dataclass})
    return ResolvedConversion(resolved)
 | 
