import sys
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Type
from xsdata.exceptions import XmlContextError
from xsdata.formats.bindings import T
from xsdata.formats.dataclass.compat import class_types
from xsdata.formats.dataclass.models.builders import XmlMetaBuilder
from xsdata.formats.dataclass.models.elements import XmlMeta
from xsdata.models.enums import DataType
from xsdata.utils.constants import return_input
[docs]
class XmlContext:
"""
The service provider for binding operations' metadata.
:param element_name_generator: Default element name generator
:param attribute_name_generator: Default attribute name generator
:param class_type: Default class type `dataclasses`
:param models_package: Restrict auto locate to a specific package
"""
__slots__ = (
"element_name_generator",
"attribute_name_generator",
"class_type",
"cache",
"xsi_cache",
"sys_modules",
"models_package",
)
def __init__(
self,
element_name_generator: Callable = return_input,
attribute_name_generator: Callable = return_input,
class_type: str = "dataclasses",
models_package: Optional[str] = None,
):
self.element_name_generator = element_name_generator
self.attribute_name_generator = attribute_name_generator
self.class_type = class_types.get_type(class_type)
self.cache: Dict[Type, XmlMeta] = {}
self.xsi_cache: Dict[str, List[Type]] = defaultdict(list)
self.models_package = models_package
self.sys_modules = 0
def reset(self):
self.cache.clear()
self.xsi_cache.clear()
self.sys_modules = 0
def get_builder(
self, globalns: Optional[Dict[str, Callable]] = None
) -> XmlMetaBuilder:
return XmlMetaBuilder(
class_type=self.class_type,
element_name_generator=self.element_name_generator,
attribute_name_generator=self.attribute_name_generator,
globalns=globalns,
)
[docs]
def fetch(
self,
clazz: Type,
parent_ns: Optional[str] = None,
xsi_type: Optional[str] = None,
) -> XmlMeta:
"""
Fetch the model metadata of the given dataclass type, namespace and xsi
type.
:param clazz: The requested dataclass type
:param parent_ns: The inherited parent namespace
:param xsi_type: if present it means that the given clazz is
derived and the lookup procedure needs to check and match a
dataclass model to the qualified name instead
"""
meta = self.build(clazz, parent_ns)
subclass = None
if xsi_type and meta.target_qname != xsi_type:
subclass = self.find_subclass(clazz, xsi_type)
return self.build(subclass, parent_ns) if subclass else meta
[docs]
def build_xsi_cache(self):
"""Index all imported dataclasses by their xsi:type qualified name."""
if len(sys.modules) == self.sys_modules:
return
self.xsi_cache.clear()
builder = self.get_builder()
for clazz in self.get_subclasses(object):
if self.is_binding_model(clazz):
meta = builder.build_class_meta(clazz)
if meta.target_qname:
self.xsi_cache[meta.target_qname].append(clazz)
self.sys_modules = len(sys.modules)
def is_binding_model(self, clazz: Type[T]) -> bool:
if not self.class_type.is_model(clazz):
return False
return not self.models_package or (
hasattr(clazz, "__module__")
and isinstance(clazz.__module__, str)
and clazz.__module__.startswith(self.models_package)
)
[docs]
def find_types(self, qname: str) -> List[Type[T]]:
"""
Find all classes that match the given xsi:type qname.
- Ignores native schema types, xs:string, xs:float, xs:int, ...
- Rebuild cache if new modules were imported since last run
:param qname: Qualified name
"""
if not DataType.from_qname(qname):
self.build_xsi_cache()
if qname in self.xsi_cache:
return self.xsi_cache[qname]
return []
[docs]
def find_type(self, qname: str) -> Optional[Type[T]]:
"""
Return the most recently imported class that matches the given xsi:type
qname.
:param qname: Qualified name
"""
types: List[Type] = self.find_types(qname)
return types[-1] if types else None
[docs]
def find_type_by_fields(self, field_names: Set[str]) -> Optional[Type[T]]:
"""
Find a dataclass from all the imported modules that matches the given
list of field names.
:param field_names: A unique list of field names
"""
def get_field_diff(clazz: Type) -> int:
meta = self.cache[clazz]
local_names = {var.local_name for var in meta.get_all_vars()}
return len(local_names - field_names)
self.build_xsi_cache()
choices = [
(clazz, get_field_diff(clazz))
for types in self.xsi_cache.values()
for clazz in types
if self.local_names_match(field_names, clazz)
]
choices.sort(key=lambda x: (x[1], x[0].__name__))
return choices[0][0] if len(choices) > 0 else None
[docs]
def find_subclass(self, clazz: Type, qname: str) -> Optional[Type]:
"""
Compare all classes that match the given xsi:type qname and return the
first one that is either a subclass or shares the same parent class as
the original class.
:param clazz: The search dataclass type
:param qname: Qualified name
"""
types: List[Type] = self.find_types(qname)
for tp in types:
# Why would an xml node with have an xsi:type that points
# to parent class is beyond me but it happens, let's protect
# against that scenario <node xsi:type="nodeAbstract" />
if issubclass(clazz, tp):
continue
for tp_mro in tp.__mro__:
if tp_mro is not object and tp_mro in clazz.__mro__:
return tp
return None
[docs]
def build(
self,
clazz: Type,
parent_ns: Optional[str] = None,
globalns: Optional[Dict[str, Callable]] = None,
) -> XmlMeta:
"""
Fetch from cache or build the binding metadata for the given class and
parent namespace.
:param clazz: A dataclass type
:param parent_ns: The inherited parent namespace
"""
if clazz not in self.cache:
builder = self.get_builder(globalns)
self.cache[clazz] = builder.build(clazz, parent_ns)
return self.cache[clazz]
[docs]
def build_recursive(self, clazz: Type, parent_ns: Optional[str] = None):
"""Build the binding metadata for the given class and all of its
dependencies."""
if clazz not in self.cache:
meta = self.build(clazz, parent_ns)
for var in meta.get_all_vars():
types = var.element_types if var.elements else var.types
for tp in types:
if self.class_type.is_model(tp):
self.build_recursive(tp, meta.namespace)
def local_names_match(self, names: Set[str], clazz: Type) -> bool:
try:
meta = self.build(clazz)
local_names = {var.local_name for var in meta.get_all_vars()}
return not names.difference(local_names)
except (XmlContextError, NameError, TypeError):
# The dataclass includes unsupported typing annotations
# Let's remove it from xsi_cache
builder = self.get_builder()
target_qname = builder.build_class_meta(clazz).target_qname
if target_qname and target_qname in self.xsi_cache:
self.xsi_cache[target_qname].remove(clazz)
return False
[docs]
@classmethod
def is_derived(cls, obj: Any, clazz: Type) -> bool:
"""
Return whether the given obj is derived from the given dataclass type.
:param obj: A dataclass instance
:param clazz: A dataclass type
"""
if obj is None:
return False
if isinstance(obj, clazz):
return True
return any(x is not object and isinstance(obj, x) for x in clazz.__bases__)
@classmethod
def get_subclasses(cls, clazz: Type):
try:
for subclass in clazz.__subclasses__():
yield from cls.get_subclasses(subclass)
yield subclass
except TypeError:
pass