#!/usr/bin/env python3
#-*- coding: utf-8 -*-#
# Python3 ORM hacking
# 说明:
# 之前分析了一个Python2 ORM的源代码,这次分析一个Python3的源代码,在写法上
# 还是又挺大的区别的。
# 2016-10-22 深圳 南山平山村 曾剑锋
#
# 源码:
# https://github.com/michaelliao/awesome-python3-webapp/tree/day-03
#
# 参考文章:
#1. python logging模块使用教程
# http://www.jianshu.com/p/feb86c06c4f4
# 2. Python async/await入门
# https://ipfans.github.io/2015/08/introduction-to-async-and-await/
# 3. 浅析python的metaclass
# http://jianpx.iteye.com/blog/908121
# 4. Why I got ignored exception when I use aiomysql in python 3.5 #59# https://github.com/aio-libs/aiomysql/issues/59
#
__author__= 'Michael Liao'import asyncio, logging
import aiomysql
# SQL日志打印输出模板
def log(sql, args=()):
logging.info('SQL: %s' %sql)
# 创建数据库连接池
async def create_pool(loop,**kw):
logging.info('create database connection pool...')
# 标记__pool为文件内全局变量,在其他函数内可以直接访问global__pool
__pool=await aiomysql.create_pool(
host=kw.get('host', 'localhost'),
port=kw.get('port', 3306),
user=kw['user'],
password=kw['password'],
db=kw['db'],
charset=kw.get('charset', 'utf8'),
autocommit=kw.get('autocommit', True),
maxsize=kw.get('maxsize', 10),
minsize=kw.get('minsize', 1),
loop=loop
)
# 数据库查询
async defselect(sql, args, size=None):
# 输出SQL日志信息
log(sql, args)global__pool
# 从连接池中获取连接,aysnc是异步获取连接
async with __pool.get() asconn:
async with conn.cursor(aiomysql.DictCursor)ascur:
# 合成实际的SQL
await cur.execute(sql.replace('?', '%s'), args or ())
# 根据size来获取数据多少行记录ifsize:
rs=await cur.fetchmany(size)else:
rs=await cur.fetchall()
# 给出获取到的信息条数
logging.info('rows returned: %s' %len(rs))returnrs
# 数据库直接执行SQL
async def execute(sql, args, autocommit=True):
log(sql)
async with __pool.get() asconn:
# 如果不是自动提交ifnot autocommit:
await conn.begin()try:
async with conn.cursor(aiomysql.DictCursor)ascur:
await cur.execute(sql.replace('?', '%s'), args)
# 返回的执行SQL后有效行数,从代码上可以看出,这部分主要是执行更新、插入、删除等SQL语句
affected=cur.rowcount
# 完成提交工作ifnot autocommit:
await conn.commit()
except BaseExceptionase:
# 出现问题,回滚ifnot autocommit:
await conn.rollback()
raise # 直接再次抛出异常returnaffected # 返回有效行数
# 合成可替代参数字符串,先使用'?'代替'%s'def create_args_string(num):
L=[]for n inrange(num):
L.append('?')return ','.join(L)
# 对应数据库中每一个字段的一个域的基类class Field(object):
# 域名、域类型、是否是主键、默认值
def __init__(self, name, column_type, primary_key,default):
self.name=name
self.column_type=column_type
self.primary_key=primary_key
self.default = default# 重写默认输出的str函数
def __str__(self):return '' %(self.__class__.__name__, self.column_type, self.name)
# 字符串类型的域classStringField(Field):
def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
super().__init__(name, ddl, primary_key,default)
# Boolean类型的域classBooleanField(Field):
def __init__(self, name=None, default=False):
super().__init__(name,'boolean', False, default)
# 整形类型的域classIntegerField(Field):
def __init__(self, name=None, primary_key=False, default=0):
super().__init__(name,'bigint', primary_key, default)
# 浮点类型的域classFloatField(Field):
def __init__(self, name=None, primary_key=False, default=0.0):
super().__init__(name,'real', primary_key, default)
# 文本类型的域classTextField(Field):
def __init__(self, name=None, default=None):
super().__init__(name,'text', False, default)
# MVC中的Model的元类,主要用于自动生成映射(map)类classModelMetaclass(type):
# name: 类的名字
# bases: 基类,通常是tuple类型
# attrs: dict类型,就是类的属性或者函数
def __new__(cls, name, bases, attrs):
# 过滤掉Model类直接生成的实例类if name=='Model':returntype.__new__(cls, name, bases, attrs)
# 从类的属性中获取__table__,其实也就是于数据库对应的表名,如果不存在那么就是等于类名
tableName= attrs.get('__table__', None) or name
logging.info('found model: %s (table: %s)' %(name, tableName))
# 创建映射字典
mappings=dict()
# 域list
fields=[]
# 主键标记
primaryKey=None
# 获取类中的所有的键值对for k, v inattrs.items():
# 选择Field类型实例的属性作为映射键值ifisinstance(v, Field):
logging.info('found mapping: %s ==> %s' %(k, v))
# 将当前的键值对放入mapping中
mappings[k]=vifv.primary_key:
# 防止出现两个、两个以上的主键ifprimaryKey:
raise StandardError('Duplicate primary key for field: %s' %k)
primaryKey=kelse:
# 将key添加进入fields中,也就是映射类中的属性和数据库中的表的域, 这里面不包含主键
fields.append(k)
# 前面可能没有找到主键,提示一下ifnot primaryKey:
raise StandardError('Primary key not found.')
# 删除这些类属性 防止访问实例属性的时候发生错误,因为实例属性优先级大于类属性for k inmappings.keys():
attrs.pop(k)
escaped_fields= list(map(lambda f: '`%s`' %f, fields))
attrs['__mappings__'] =mappings # 保存属性和列的映射关系
attrs['__table__'] =tableName # 表名
attrs['__primary_key__'] =primaryKey # 主键属性名
attrs['__fields__'] =fields # 除主键外的属性名
# 生成查询SQL
attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ','.join(escaped_fields), tableName)
# 生成插入SQL
attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ','.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
# 生成更新SQL
attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ','.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
# 生成删除SQL
attrs['__delete__'] = 'delete from `%s` where `%s`=?' %(tableName, primaryKey)
# 调用type生成类returntype.__new__(cls, name, bases, attrs)
# 继承自ModelMetaclass元类、dict的类class Model(dict, metaclass=ModelMetaclass):
def __init__(self,**kw):
super(Model, self).__init__(**kw)
# 重写get方法
def __getattr__(self, key):try:returnself[key]
except KeyError:
raise AttributeError(r"'Model' object has no attribute '%s'" %key)
# 重写set方法
def __setattr__(self, key, value):
self[key]=value
# 重写get方法
def getValue(self, key):returngetattr(self, key, None)
# 获取值,当不存在的时候获取的是默认值
def getValueOrDefault(self, key):
value=getattr(self, key, None)if value isNone:
field=self.__mappings__[key]if field.default isnot None:
value= field.default() if callable(field.default) else field.defaultlogging.debug('using default value for %s: %s' %(key, str(value)))
setattr(self, key, value)returnvalue
@classmethod
async def findAll(cls,where=None, args=None, **kw):'find objects by where clause.'# 获取元类自动生成的SQL语句,并根据当前的参数,继续合成
sql=[cls.__select__]if where:
sql.append('where')
sql.append(where)if args isNone:
args=[]
orderBy= kw.get('orderBy', None)iforderBy:
sql.append('order by')
sql.append(orderBy)
limit= kw.get('limit', None)if limit isnot None:
sql.append('limit')if isinstance(limit, int):
sql.append('?')
args.append(limit)
elif isinstance(limit, tuple) and len(limit)== 2:
sql.append('?, ?')
args.extend(limit)else:
raise ValueError('Invalid limit value: %s' %str(limit))
# 直接调用select函数来处理, 这里是等待函数执行完成函数才能返回
rs= await select(' '.join(sql), args)
# 该类本身是字典,自己用自己生成新的实例,里面的阈值正好也是需要查询return [cls(**r) for r inrs]
@classmethod
async def findNumber(cls, selectField,where=None, args=None):'find number by select and where.'# 这里的 _num_ 什么意思?别名? 我估计是mysql里面一个记录实时查询结果条数的变量
sql= ['select %s _num_ from `%s`' %(selectField, cls.__table__)]if where:
sql.append('where')
sql.append(where)
rs= await select(' '.join(sql), args, 1)if len(rs) == 0:returnNonereturn rs[0]['_num_']
@classmethod
async def find(cls, pk):'find object by primary key.'# 通过主键查找对象, 如果不存在,那么就返回None
rs= await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)if len(rs) == 0:returnNonereturn cls(**rs[0])
# 插入语句对应的方法
async def save(self):
args=list(map(self.getValueOrDefault, self.__fields__))
args.append(self.getValueOrDefault(self.__primary_key__))
rows=await execute(self.__insert__, args)if rows != 1:
logging.warn('failed to insert record: affected rows: %s' %rows)
# 更新语句对应的方法
async def update(self):
args=list(map(self.getValue, self.__fields__))
args.append(self.getValue(self.__primary_key__))
rows=await execute(self.__update__, args)if rows != 1:
logging.warn('failed to update by primary key: affected rows: %s' %rows)
# 删除语句对应的方法
async def remove(self):
args=[self.getValue(self.__primary_key__)]
rows=await execute(self.__delete__, args)if rows != 1:
logging.warn('failed to remove by primary key: affected rows: %s' % rows)