#!/usr/bin/env python
# cardinal_pythonlib/sqlalchemy/orm_inspect.py
"""
===============================================================================
Original code copyright (C) 2009-2022 Rudolf Cardinal (rudolf@pobox.com).
This file is part of cardinal_pythonlib.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
===============================================================================
**Functions to inspect and copy SQLAlchemy ORM objects.**
"""
from typing import (
Dict,
Generator,
List,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
# noinspection PyProtectedMember
from sqlalchemy import inspect
from sqlalchemy.orm.base import class_mapper
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.relationships import RelationshipProperty
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.schema import Column, MetaData
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.sql.visitors import Visitable
from sqlalchemy.util import OrderedProperties
from cardinal_pythonlib.classes import gen_all_subclasses
from cardinal_pythonlib.enumlike import OrderedNamespace
from cardinal_pythonlib.dicts import reversedict
from cardinal_pythonlib.logs import get_brace_style_log_with_null_handler
if TYPE_CHECKING:
from sqlalchemy.orm.state import InstanceState
from sqlalchemy.sql.schema import Table
# Module-level logger. Log calls in this module pass brace-style ("{}")
# placeholders plus arguments, rather than %-style format strings.
log = get_brace_style_log_with_null_handler(__name__)
# =============================================================================
# Constants
# =============================================================================
# Type-hint alias meaning "a class derived from Visitable". SQLAlchemy 2.0 no
# longer provides VisitableType, so it is recreated here for annotations only
# (it is not suitable for isinstance() checks).
VisitableType = Type[Visitable]  # for SQLAlchemy 2.0
# =============================================================================
# Creating ORM objects conveniently, etc.
# =============================================================================
def coltype_as_typeengine(
    coltype: Union[VisitableType, TypeEngine]
) -> TypeEngine:
    """
    Coerces a column type specification to a :class:`TypeEngine` instance.

    SQLAlchemy column types may be given either as an instance (e.g.
    ``Integer()``, ``String(length=50)``) or as the bare class (e.g.
    ``Integer``). Instances of ``TypeEngine`` are returned unchanged; anything
    else is called with no arguments and the result is returned.

    Background:

    .. code-block:: python

        from sqlalchemy.sql.schema import Column
        from sqlalchemy.sql.sqltypes import Integer, String, TypeEngine

        a = Column("a", Integer)
        b = Column("b", Integer())
        c = Column("c", String(length=50))

        # In SQLAlchemy to 1.4:

        isinstance(Integer, TypeEngine)  # False
        isinstance(Integer(), TypeEngine)  # True
        isinstance(String(length=50), TypeEngine)  # True

        type(Integer)  # <class 'sqlalchemy.sql.visitors.VisitableType'>
        type(Integer())  # <class 'sqlalchemy.sql.sqltypes.Integer'>
        type(String)  # <class 'sqlalchemy.sql.visitors.VisitableType'>
        type(String(length=50))  # <class 'sqlalchemy.sql.sqltypes.String'>

        # In SQLAlchemy 2.0, VisitableType has gone. Though there is
        # Visitable. (So we can also recreate VisitableType, although only
        # for type hints, not e.g. isinstance.)

        from sqlalchemy.sql.visitors import Visitable

        isinstance(Integer, TypeEngine)  # False
        isinstance(Integer(), TypeEngine)  # True
        isinstance(String(length=50), TypeEngine)  # True

        type(Integer)  # <class 'type'>
        issubclass(Integer, Visitable)  # True
        type(Integer())  # <class 'sqlalchemy.sql.sqltypes.Integer'>
        type(String)  # <class 'type'>
        issubclass(String, Visitable)  # True
        type(String(length=50))  # <class 'sqlalchemy.sql.sqltypes.String'>

    Args:
        coltype: a ``TypeEngine`` instance, or a class (callable with no
            arguments) that produces one

    Returns:
        a ``TypeEngine`` instance

    Raises:
        AssertionError: if calling ``coltype()`` does not yield a
            ``TypeEngine`` instance
    """
    if not isinstance(coltype, TypeEngine):
        # Given a class rather than an instance: instantiate with defaults.
        coltype = coltype()
        assert isinstance(coltype, TypeEngine)
    return coltype
# =============================================================================
# SqlAlchemyAttrDictMixin
# =============================================================================
class SqlAlchemyAttrDictMixin(object):
    """
    Mixin for SQLAlchemy ORM classes that:

    - exports the object's column values as a plain, ordered, attribute-style
      container (so that ``x.y`` can be used rather than ``x['y']``);
    - supplies a default ``repr()`` listing the fields in column order.

    See https://stackoverflow.com/questions/2537471 and in particular
    https://stackoverflow.com/questions/2441796.
    """

    def get_attrdict(self) -> OrderedNamespace:
        """
        Returns an :class:`OrderedNamespace` holding one entry per key of
        ``self.__table__.columns``, whose value is the attribute of that name
        read from this object.
        """
        # noinspection PyUnresolvedReferences
        colnames = self.__table__.columns.keys()
        return OrderedNamespace(
            zip(colnames, (getattr(self, name) for name in colnames))
        )

    def __repr__(self) -> str:
        # Format: <ClassName(field1=value1, field2=value2, ...)>
        pairs = [
            f"{name}={value!r}"
            for name, value in self.get_attrdict().items()
        ]
        return f"<{type(self).__name__}({', '.join(pairs)})>"

    @classmethod
    def from_attrdict(cls, attrdict: OrderedNamespace) -> object:
        """
        Creates a new instance of this class, passing the contents of
        ``attrdict.__dict__`` to the constructor as keyword arguments.
        """
        # noinspection PyArgumentList
        return cls(**attrdict.__dict__)
# =============================================================================
# Traverse ORM relationships (SQLAlchemy ORM)
# =============================================================================
def walk_orm_tree(
    obj,
    debug: bool = False,
    seen: Set = None,
    skip_relationships_always: List[str] = None,
    skip_relationships_by_tablename: Dict[str, List[str]] = None,
    skip_all_relationships_for_tablenames: List[str] = None,
    skip_all_objects_for_tablenames: List[str] = None,
) -> Generator[object, None, None]:
    """
    Starting with a SQLAlchemy ORM object, this function walks a
    relationship tree (breadth-first), yielding each of the objects once.

    To skip relationships by attribute name, put the attribute name(s) in
    ``skip_relationships_always``. To skip them only for objects of a
    particular table, pass ``skip_relationships_by_tablename`` as e.g.

    .. code-block:: python

        {'sometable': ['attr1_to_skip', 'attr2_to_skip']}

    Args:
        obj: the SQLAlchemy ORM object to walk
        debug: be verbose
        seen: usually ``None``, but can be a set of objects marked as "already
            seen"; if an object is in this set, it is skipped. The set is
            modified in place (every yielded object is added to it), so a
            caller can share one set across several walks.
        skip_relationships_always: relationships are skipped if the
            relationship has a name in this (optional) list
        skip_relationships_by_tablename: optional dictionary mapping table
            names (keys) to lists of relationship attribute names (values);
            a relationship is skipped if the table name of the object that
            owns it, and its attribute name, form a pair in this dictionary
        skip_all_relationships_for_tablenames: no relationships are followed
            from an object whose table name is in this (optional) list (the
            object itself is still yielded)
        skip_all_objects_for_tablenames: if the object belongs to a table whose
            name is in this (optional) list, the object is skipped (neither
            yielded nor traversed)

    Yields:
        SQLAlchemy ORM objects (including the starting object)
    """
    # http://docs.sqlalchemy.org/en/latest/faq/sessions.html#faq-walk-objects

    # Imported locally so that this function does not depend on a change to
    # the module's import block.
    from collections import deque

    skip_relationships_always = (
        skip_relationships_always or []
    )  # type: List[str]
    skip_relationships_by_tablename = (
        skip_relationships_by_tablename or {}
    )  # type: Dict[str, List[str]]
    skip_all_relationships_for_tablenames = (
        skip_all_relationships_for_tablenames or []
    )  # type: List[str]
    skip_all_objects_for_tablenames = (
        skip_all_objects_for_tablenames or []
    )  # type: List[str]

    # Test against None explicitly: an empty set supplied by the caller must
    # be used (and filled), not replaced.
    if seen is None:
        seen = set()

    # FIFO queue of objects awaiting a visit. A deque gives O(1) removal from
    # the front, where list.pop(0) is O(n).
    queue = deque([obj])
    while queue:
        obj = queue.popleft()
        if obj in seen:
            continue
        tablename = obj.__tablename__
        if tablename in skip_all_objects_for_tablenames:
            continue
        seen.add(obj)
        if debug:
            log.debug("walk: yielding {!r}", obj)
        yield obj

        # This test depends only on the object, not on the individual
        # relationship, so it is made once per object.
        if tablename in skip_all_relationships_for_tablenames:
            continue
        # Relationship names to skip for objects of this particular table.
        skip_for_this_table = skip_relationships_by_tablename.get(
            tablename, ()
        )

        insp = inspect(obj)  # type: InstanceState
        for relationship in insp.mapper.relationships:
            attrname = relationship.key
            # Skip?
            if attrname in skip_relationships_always:
                continue
            if attrname in skip_for_this_table:
                continue
            # Process relationship
            if debug:
                log.debug("walk: following relationship {}", relationship)
            related = getattr(obj, attrname)
            if debug and related:
                log.debug("walk: queueing {!r}", related)
            if relationship.uselist:
                # Collection-valued relationship: queue every member.
                queue.extend(related)
            elif related is not None:
                # Scalar relationship that is actually populated.
                queue.append(related)
# =============================================================================
# deepcopy an SQLAlchemy object
# =============================================================================
# Use case: object X is in the database; we want to clone it to object Y,
# which we can then save to the database, i.e. copying all SQLAlchemy field
# attributes of X except its PK. We also want it to copy anything that is
# dependent upon X, i.e. traverse relationships.
#
# https://groups.google.com/forum/#!topic/sqlalchemy/wb2M_oYkQdY
# https://groups.google.com/forum/#!searchin/sqlalchemy/cascade%7Csort:date/sqlalchemy/eIOkkXwJ-Ms/JLnpI2wJAAAJ # noqa: E501
def copy_sqla_object(
    obj: object,
    omit_fk: bool = True,
    omit_pk: bool = True,
    omit_attrs: List[str] = None,
    debug: bool = False,
) -> object:
    """
    Makes a shallow, relationship-free copy of an SQLAlchemy ORM object.

    A new object of the same class is created by calling the class with no
    arguments (SO THE CLASS MUST SUPPORT ``__init__()`` WITH NO PARAMETERS).
    Mapped attributes are then copied across, except for relationship
    attributes (never copied), primary keys (omitted by default), foreign
    keys (omitted by default), and anything named in ``omit_attrs``.

    Attributes whose read or write raises ``AttributeError`` are silently
    skipped (logged if ``debug`` is set).

    Args:
        obj: the object to copy
        omit_fk: omit foreign keys (FKs)?
        omit_pk: omit primary keys (PKs)?
        omit_attrs: attributes (by name) not to copy
        debug: be verbose

    Returns:
        a new copy of the object
    """
    omit_attrs = omit_attrs or []  # type: List[str]
    source_class = type(obj)
    mapper = class_mapper(source_class)
    newobj = source_class()  # not: source_class.__new__(source_class)

    # Build the set of attribute names that must not be copied.
    excluded = {rel.key for rel in mapper.relationships}
    if omit_pk:
        excluded |= {col.key for col in mapper.primary_key}
    if omit_fk:
        excluded |= {col.key for col in mapper.columns if col.foreign_keys}
    excluded |= set(omit_attrs)
    if debug:
        log.debug("copy_sqla_object: skipping: {}", excluded)

    # Snapshot the names to copy before touching either object.
    wanted = [
        prop.key
        for prop in mapper.iterate_properties
        if prop.key not in excluded
    ]
    for attrname in wanted:
        try:
            value = getattr(obj, attrname)
            if debug:
                log.debug(
                    "copy_sqla_object: processing attribute {} = {}",
                    attrname,
                    value,
                )
            setattr(newobj, attrname, value)
        except AttributeError:
            # Best effort: an attribute that cannot be read/written is left
            # at whatever the fresh object already has.
            if debug:
                log.debug("copy_sqla_object: failed attribute {}", attrname)
    return newobj
def rewrite_relationships(
    oldobj: object,
    newobj: object,
    objmap: Dict[object, object],
    debug: bool = False,
    skip_table_names: List[str] = None,
) -> None:
    """
    A utility function only.
    Used in copying objects between SQLAlchemy sessions.

    Both ``oldobj`` and ``newobj`` are SQLAlchemy instances, and ``newobj`` is
    already a copy of ``oldobj``. This function makes the relationships of
    ``newobj`` mirror those of ``oldobj``, translating every related object
    through ``objmap`` (which maps old objects to new objects).

    For example:

    - Suppose a source session has a Customer record and a Sale record
      containing ``sale.customer_id``, a foreign key to Customer.

    - We may have corresponding Python SQLAlchemy ORM objects
      ``customer_1_src`` and ``sale_1_src``.

    - We copy them into a destination database, where their Python ORM objects
      are ``customer_1_dest`` and ``sale_1_dest``.

    - In the process we set up an object map looking like:

      .. code-block:: none

        Old session     New session
        -------------------------------
        customer_1_src  customer_1_dest
        sale_1_src      sale_1_dest

    - Now, we wish to make ``sale_1_dest`` have a relationship to
      ``customer_1_dest``, in the same way that ``sale_1_src`` has a
      relationship to ``customer_1_src``. Called with ``oldobj=sale_1_src``
      and ``newobj=sale_1_dest``, this function sees that ``sale_1_src``
      relates to ``customer_1_src``, looks that up in ``objmap`` to find
      ``customer_1_dest``, and sets the same relationship attribute on
      ``sale_1_dest`` to ``customer_1_dest``.

    Args:
        oldobj: SQLAlchemy ORM object to read from
        newobj: SQLAlchemy ORM object to write to
        objmap: dictionary mapping "source" objects to their corresponding
            "destination" object.
        debug: be verbose
        skip_table_names: if a related table's name is in this (optional) list,
            that relationship is skipped

    Raises:
        KeyError: if a related object of ``oldobj`` (in a relationship that
            is not skipped) is not a key of ``objmap``
    """
    skip_table_names = skip_table_names or []  # type: List[str]
    # The mapper's relationships collection behaves like an AttrDict of
    # attribute name -> RelationshipProperty. It is shared by old and new
    # objects, so only the attribute name is needed from it.
    relationships = inspect(oldobj).mapper.relationships
    for attrname, rel_prop in relationships.items():
        if rel_prop.viewonly:
            if debug:
                log.debug("Skipping viewonly relationship")
            continue  # don't attempt to write viewonly relationships

        related_table_name = rel_prop.mapper.class_.__tablename__  # type: str
        if related_table_name in skip_table_names:
            if debug:
                log.debug(
                    "Skipping relationship for related table {!r}",
                    related_table_name,
                )
            continue

        related_old = getattr(oldobj, attrname)
        if rel_prop.uselist:
            # Collection: translate each member.
            related_new = [objmap[member] for member in related_old]
        elif related_old is None:
            related_new = None
        else:
            related_new = objmap[related_old]

        if debug:
            log.debug(
                "rewrite_relationships: relationship {} -> {}",
                attrname,
                related_new,
            )
        setattr(newobj, attrname, related_new)
def deepcopy_sqla_objects(
    startobjs: List[object],
    session: Session,
    flush: bool = True,
    debug: bool = False,
    debug_walk: bool = True,
    debug_rewrite_rel: bool = False,
    objmap: Dict[object, object] = None,
) -> None:
    """
    Makes a copy of the specified SQLAlchemy ORM objects, inserting them into a
    new session.

    This function operates in several passes:

    1. Walk the ORM tree through all objects and their relationships, copying
       every object thus found (via :func:`copy_sqla_object`, without their
       relationships), and building a map from each source-session object to
       its equivalent destination-session object.

    2. Work through every entry in the object map, rewriting the relationships
       of the destination objects (via :func:`rewrite_relationships`) so they
       relate to each other (rather than their source-session brethren).

    3. Add every destination object in the map to the destination session.

    Passes 2 and 3 cover the whole of ``objmap``, so entries that were already
    present when the function was called are processed too.

    For this to succeed, every object must take an ``__init__`` call with no
    arguments (see :func:`copy_sqla_object`). (We can't specify the required
    ``args``/``kwargs``, since we are copying a tree of arbitrary objects.)

    Args:
        startobjs: SQLAlchemy ORM objects to copy
        session: destination SQLAlchemy :class:`Session` into which to insert
            the copies
        flush: flush the session when we've finished?
        debug: be verbose?
        debug_walk: be extra verbose when walking the ORM tree? (Note that
            this defaults to ``True`` here.)
        debug_rewrite_rel: be extra verbose when rewriting relationships?
        objmap: starting object map from source-session to destination-session
            objects (see :func:`rewrite_relationships` for more detail);
            usually ``None`` to begin with. If supplied, it is updated in
            place.
    """
    if objmap is None:
        objmap = {}  # keys = old objects, values = new objects

    # -------------------------------------------------------------------------
    # Pass 1: copy every reachable object. (Relationships can't be set
    # reliably until all copies exist, since we don't know whether or where
    # the "root" of the PK tree is.)
    # -------------------------------------------------------------------------
    if debug:
        log.debug("deepcopy_sqla_objects: pass 1: create new objects")
    visited = set()  # shared across start objects so nothing is copied twice
    for root in startobjs:
        walker = walk_orm_tree(root, seen=visited, debug=debug_walk)
        for source in walker:
            if debug:
                log.debug("deepcopy_sqla_objects: copying {}", source)
            # The copy is deliberately NOT added to the session yet: doing so
            # may trigger an autoflush as relationships are queried, before
            # the new objects' relationships are set. (Compare the
            # session.no_autoflush option, and errors such as
            # "sqlalchemy.exc.OperationalError: (raised as a result of Query-
            # invoked autoflush; consider using a session.no_autoflush block
            # if this flush is occurring prematurely)...")
            objmap[source] = copy_sqla_object(
                source, omit_pk=True, omit_fk=True
            )

    # -------------------------------------------------------------------------
    # Pass 2: set all relationship properties.
    # -------------------------------------------------------------------------
    if debug:
        log.debug("deepcopy_sqla_objects: pass 2: set relationships")
    for source, dest in objmap.items():
        if debug:
            log.debug("deepcopy_sqla_objects: newobj: {}", dest)
        rewrite_relationships(source, dest, objmap, debug=debug_rewrite_rel)

    # -------------------------------------------------------------------------
    # Pass 3: now we can do the session insert.
    # -------------------------------------------------------------------------
    if debug:
        log.debug("deepcopy_sqla_objects: pass 3: insert into session")
    for dest in objmap.values():
        session.add(dest)

    # Done
    if debug:
        log.debug("deepcopy_sqla_objects: done")
    if flush:
        session.flush()
def deepcopy_sqla_object(
    startobj: object,
    session: Session,
    flush: bool = True,
    debug: bool = False,
    debug_walk: bool = False,
    debug_rewrite_rel: bool = False,
    objmap: Dict[object, object] = None,
) -> object:
    """
    Makes a copy of the object, inserting it into ``session``.

    A single-object convenience wrapper around
    :func:`deepcopy_sqla_objects` (q.v.).

    A problem is the creation of duplicate dependency objects if you call it
    repeatedly.

    Optionally, if you pass the objmap in (which maps old to new objects), you
    can call this function repeatedly to clone a related set of objects...
    ... no, that doesn't really work, as it doesn't visit parents before
    children. The :func:`cardinal_pythonlib.sqlalchemy.merge_db.merge_db`
    function does that properly.

    Args:
        startobj: SQLAlchemy ORM object to deep-copy
        session: see :func:`deepcopy_sqla_objects`
        flush: see :func:`deepcopy_sqla_objects`
        debug: see :func:`deepcopy_sqla_objects`
        debug_walk: see :func:`deepcopy_sqla_objects`
        debug_rewrite_rel: see :func:`deepcopy_sqla_objects`
        objmap: see :func:`deepcopy_sqla_objects`

    Returns:
        the copied object matching ``startobj``
    """
    # keys = old objects, values = new objects
    old_to_new = {} if objmap is None else objmap
    deepcopy_sqla_objects(
        startobjs=[startobj],
        session=session,
        objmap=old_to_new,
        flush=flush,
        debug=debug,
        debug_walk=debug_walk,
        debug_rewrite_rel=debug_rewrite_rel,
    )
    # The map now holds the new object corresponding to startobj.
    return old_to_new[startobj]
# =============================================================================
# Get Columns from an ORM instance
# =============================================================================
def gen_columns(obj) -> Generator[Tuple[str, Column], None, None]:
    """
    Asks a SQLAlchemy ORM object: "what are your SQLAlchemy columns?"

    Yields tuples of ``(attr_name, Column)`` taken from the mapper of an
    SQLAlchemy ORM object instance. Also works with the corresponding
    SQLAlchemy ORM class. Examples:

    .. code-block:: python

        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy.sql.schema import Column
        from sqlalchemy.sql.sqltypes import Integer

        Base = declarative_base()

        class MyClass(Base):
            __tablename__ = "mytable"
            pk = Column("pk", Integer, primary_key=True, autoincrement=True)
            a = Column("a", Integer)

        x = MyClass()

        list(gen_columns(x))
        list(gen_columns(MyClass))

    """
    mapped_columns = dict(inspect(obj).mapper.columns)
    # NB: the dictionary keys are attribute names; column.name is the SQL
    # column name, which may differ.
    yield from mapped_columns.items()
def get_pk_attrnames(obj) -> List[str]:
    """
    Asks an SQLAlchemy ORM object: "what are your primary key(s)?"

    Args:
        obj: SQLAlchemy ORM object

    Returns:
        list of attribute names of primary-key columns, in the order in which
        :func:`gen_columns` produces them
    """
    pk_attrnames = []  # type: List[str]
    for attrname, column in gen_columns(obj):
        if column.primary_key:
            pk_attrnames.append(attrname)
    return pk_attrnames
def gen_columns_for_uninstrumented_class(
    cls: Type,
) -> Generator[Tuple[str, Column], None, None]:
    """
    Generate ``(attr_name, Column)`` tuples from an UNINSTRUMENTED class, i.e.
    one that does not inherit from ``declarative_base()``. Use this for mixins
    of that kind.

    Works by examining every name reported by ``dir(cls)`` and keeping those
    whose value is a :class:`Column`.

    SUBOPTIMAL. May produce warnings like:

    .. code-block:: none

        SAWarning: Unmanaged access of declarative attribute id from non-mapped class GenericTabletRecordMixin

    Try to use :func:`gen_columns` instead.
    """  # noqa: E501
    for attrname in dir(cls):
        candidate = getattr(cls, attrname)
        if not isinstance(candidate, Column):
            continue
        yield attrname, candidate
def attrname_to_colname_dict(cls) -> Dict[str, str]:
    """
    Asks an SQLAlchemy class how its attribute names correspond to database
    column names.

    Args:
        cls: SQLAlchemy ORM class

    Returns:
        a dictionary mapping attribute names to database column names
    """
    return {attrname: column.name for attrname, column in gen_columns(cls)}
def colname_to_attrname_dict(cls) -> Dict[str, str]:
    """
    Returns the result of applying ``reversedict`` to
    :func:`attrname_to_colname_dict` for this class.

    Presumably this is a dictionary mapping database column names to
    attribute names (assuming ``reversedict`` swaps keys and values -- TODO
    confirm against ``cardinal_pythonlib.dicts``).

    Args:
        cls: SQLAlchemy ORM class
    """
    attr_to_col = attrname_to_colname_dict(cls)
    return reversedict(attr_to_col)
# =============================================================================
# Get relationships from an ORM instance
# =============================================================================
def gen_relationships(
    obj,
) -> Generator[Tuple[str, RelationshipProperty, Type], None, None]:
    """
    Yields tuples of ``(attrname, RelationshipProperty, related_class)``
    for all relationships of an ORM object.
    The object 'obj' can be EITHER an instance OR a class.
    """
    # The mapper's relationships collection is of type
    # sqlalchemy.utils._collections.ImmutableProperties, which is basically
    # a sort of AttrDict (attribute name -> RelationshipProperty).
    relationships = inspect(obj).mapper.relationships
    for attrname, rel_prop in relationships.items():
        # The class at the far end of the relationship comes from the
        # relationship's own mapper.
        # noinspection PyUnresolvedReferences
        yield attrname, rel_prop, rel_prop.mapper.class_
# =============================================================================
# Inspect ORM objects (SQLAlchemy ORM)
# =============================================================================
def get_orm_columns(cls: Type) -> List[Column]:
    """
    Gets :class:`Column` objects from an SQLAlchemy ORM class.
    Does not provide their attribute names.
    """
    # inspect() returns a Mapper when called with an ORM class:
    # http://docs.sqlalchemy.org/en/latest/orm/mapping_api.html#sqlalchemy.orm.mapper.Mapper  # noqa: E501
    # (and an InstanceState when called with an ORM object:
    # http://docs.sqlalchemy.org/en/latest/orm/session_state_management.html#session-object-states)  # noqa: E501
    mapper = inspect(cls)  # type: Mapper
    return mapper.columns.values()
def get_orm_column_names(cls: Type, sort: bool = False) -> List[str]:
    """
    Gets column names (that is, database column names) from an SQLAlchemy
    ORM class.

    Args:
        cls: SQLAlchemy ORM class
        sort: return the names sorted, rather than in the order in which
            :func:`get_orm_columns` supplies the columns?
    """
    colnames = [column.name for column in get_orm_columns(cls)]
    if sort:
        colnames.sort()
    return colnames
# =============================================================================
# Inspect metadata (SQLAlchemy ORM)
# =============================================================================
def gen_orm_classes_from_base(base: Type) -> Generator[Type, None, None]:
    """
    From an SQLAlchemy ORM base class, yield all the subclasses (except those
    that are abstract).

    If you begin with the proper :class:`Base` class, then this should give all
    ORM classes in use.
    """
    for cls in gen_all_subclasses(base):
        # Only a class's OWN __dict__ is consulted (so "__abstract__" is not
        # inherited by subclasses). This is SQLAlchemy's own way of detecting
        # abstract classes; see sqlalchemy.orm.decl_base. Abstract classes
        # are NOT ORM classes.
        if not cls.__dict__.get("__abstract__", False):
            yield cls
def get_orm_classes_by_table_name_from_base(base: Type) -> Dict[str, Type]:
    """
    Given an SQLAlchemy ORM base class, returns a dictionary whose keys are
    table names and whose values are ORM classes.

    If you begin with the proper :class:`Base` class, then this should give all
    tables and ORM classes in use.
    """
    classes_by_table = {}  # type: Dict[str, Type]
    for orm_class in gen_orm_classes_from_base(base):
        # noinspection PyUnresolvedReferences
        classes_by_table[orm_class.__tablename__] = orm_class
    return classes_by_table