examples.extending_query.filter_public
优质
小牛编辑
138浏览
2023-12-01
"""Illustrates a global criteria applied to entities of a particular type. The example here is the "public" flag, a simple boolean that indicates the rows are part of a publicly viewable subcategory. Rows that do not include this flag are not shown unless a special option is passed to the query. Uses for this kind of recipe include tables that have "soft deleted" rows marked as "deleted" that should be skipped, rows that have access control rules that should be applied on a per-request basis, etc. """ from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy import event from sqlalchemy import orm from sqlalchemy import true from sqlalchemy.orm import Session @event.listens_for(Session, "do_orm_execute") def _add_filtering_criteria(execute_state): """Intercept all ORM queries. Add a with_loader_criteria option to all of them. This option applies to SELECT queries and adds a global WHERE criteria (or as appropriate ON CLAUSE criteria for join targets) to all objects of a certain class or superclass. """ # the with_loader_criteria automatically applies itself to # relationship loads as well including lazy loads. So if this is # a relationship load, assume the option was set up from the top level # query. if ( not execute_state.is_column_load and not execute_state.is_relationship_load and not execute_state.execution_options.get("include_private", False) ): execute_state.statement = execute_state.statement.options( orm.with_loader_criteria( HasPrivate, lambda cls: cls.public == true(), include_aliases=True, ) ) class HasPrivate(object): """Mixin that identifies a class as having private entities""" public = Column(Boolean, nullable=False) if __name__ == "__main__": from sqlalchemy import Integer, Column, String, ForeignKey, Boolean from sqlalchemy import select from sqlalchemy import create_engine from sqlalchemy.orm import relationship, sessionmaker from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() class User(HasPrivate, Base): __tablename__ = "user" id = Column(Integer, primary_key=True) name = Column(String) addresses = relationship("Address", back_populates="user") class Address(HasPrivate, Base): __tablename__ = "address" id = Column(Integer, primary_key=True) email = Column(String) user_id = Column(Integer, ForeignKey("user.id")) user = relationship("User", back_populates="addresses") engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) Session = sessionmaker(bind=engine, future=True) sess = Session() sess.add_all( [ User( name="u1", public=True, addresses=[ Address(email="u1a1", public=True), Address(email="u1a2", public=True), ], ), User( name="u2", public=True, addresses=[ Address(email="u2a1", public=False), Address(email="u2a2", public=True), ], ), User( name="u3", public=False, addresses=[ Address(email="u3a1", public=False), Address(email="u3a2", public=False), ], ), User( name="u4", public=False, addresses=[ Address(email="u4a1", public=False), Address(email="u4a2", public=True), ], ), User( name="u5", public=True, addresses=[ Address(email="u5a1", public=True), Address(email="u5a2", public=False), ], ), ] ) sess.commit() # now querying Address or User objects only gives us the public ones for u1 in sess.query(User).options(orm.selectinload(User.addresses)): assert u1.public # the addresses collection will also be "public only", which works # for all relationship loaders including joinedload for address in u1.addresses: assert address.public # works for columns too cols = ( sess.query(User.id, Address.id) .join(User.addresses) .order_by(User.id, Address.id) .all() ) assert cols == [(1, 1), (1, 2), (2, 4), (5, 9)] cols = ( sess.query(User.id, Address.id) .join(User.addresses) .order_by(User.id, Address.id) .execution_options(include_private=True) .all() ) assert cols == [ (1, 1), (1, 2), (2, 3), (2, 4), (3, 5), (3, 6), (4, 7), (4, 8), (5, 9), (5, 10), ] # count all public addresses assert sess.query(Address).count() == 5 # count all addresses public and private assert ( sess.query(Address).execution_options(include_private=True).count() == 10 ) # load an Address that is public, but its parent User is private # (2.0 style query) a1 = sess.execute(select(Address).filter_by(email="u4a2")).scalar() # assuming the User isn't already in the Session, it returns None assert a1.user is None # however, if that user is present in the session, then a many-to-one # does a simple get() and it will be present sess.expire(a1, ["user"]) u1 = sess.execute( select(User) .filter_by(name="u4") .execution_options(include_private=True) ).scalar() assert a1.user is u1