sqlalchemy 中 scoped_session 与 threading.local() 的关系

关键词:scoped_sessionthreading.local()

下面是一段操作 MySQL 连接的代码,今天来分析下里面 get_dbsession 这个函数牵扯到的一些关于线程的知识。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker
from . import log
try:
from greenlet import getcurrent as get_ident
except ImportError:
try:
from thread import get_ident
except ImportError:
from _thread import get_ident
BaseModel = declarative_base()
dbsession_cache = {}
dbsession_used = {}
logger = log.get_logger('dock.sqldb')
def create_sqldb_engine(url, options={}):
kwargs = {
'pool_size': 20,
'max_overflow': 0,
'pool_recycle': 3600
}
kwargs.update(options or {})
return create_engine(url, **kwargs)
def update_dbsession_used(name):
ident = get_ident()
used = dbsession_used.setdefault(ident, set())
used.add(name)
def get_dbsession(name='default'):
from .globals import config
if name not in dbsession_cache:
conf = config.sqldb.get(name, {})
engine = create_sqldb_engine(conf['url'], conf.get('options', {}))
# threading safe
# scoped_session 接收的参数是一个 sessionmaker ,这个是确保的前提
dbsession = scoped_session(sessionmaker(autocommit=False,
autoflush=False, bind=engine))
dbsession_cache[name] = dbsession
dbsession_class = dbsession_cache[name]
update_dbsession_used(name)
dbsession_class()
return dbsession_class
def clear_dbsession():
global dbsession_used
ident = get_ident()
dbsessions = dbsession_used.pop(ident, [])
for name in dbsessions:
try:
dbsession_class = dbsession_cache.get(name)
dbsession_class.remove()
except:
logger.error('error while clear dbsession')
logger.traceback()

我们看来看这个 get_dbsession 涉及到的一些东西 :

前提: 当我们已经拿到了一个 dbsession_class 对象,并使用对应的 key 存储到了 dbsession_cache 中去。

当我们根据这个 key,拿出了这个 dbsession_class 对象之后,我们看到 dbsession_class 调用了一个 call 方法(括号方法),dbsession_class 是 scoped_session 的一个对象,也就是说,调用了如下代码中的 def __call__(self, **kw): 这个方法,而这个方法最后做的事情是返回了 self.registry() 这个东西,也就是相当于调用了 self.registry 这个属性的 call 方法(括号方法),那我们看到 self.registry 这个属性值是一个对象的实例 ThreadLocalRegistry(session_factory)。 (下一段开始讲这个实例,下方代码是这段话的代码展示)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class scoped_session(object):
"""Provides scoped management of :class:`.Session` objects.
See :ref:`unitofwork_contextual` for a tutorial.
"""
session_factory = None
"""The `session_factory` provided to `__init__` is stored in this
attribute and may be accessed at a later time. This can be useful when
a new non-scoped :class:`.Session` or :class:`.Connection` to the
database is needed."""
def __init__(self, session_factory, scopefunc=None):
"""Construct a new :class:`.scoped_session`.
:param session_factory: a factory to create new :class:`.Session`
instances. This is usually, but not necessarily, an instance
of :class:`.sessionmaker`.
:param scopefunc: optional function which defines
the current scope. If not passed, the :class:`.scoped_session`
object assumes "thread-local" scope, and will use
a Python ``threading.local()`` in order to maintain the current
:class:`.Session`. If passed, the function should return
a hashable token; this token will be used as the key in a
dictionary in order to store and retrieve the current
:class:`.Session`.
"""
self.session_factory = session_factory
if scopefunc:
self.registry = ScopedRegistry(session_factory, scopefunc)
else:
self.registry = ThreadLocalRegistry(session_factory)
def __call__(self, **kw):
r"""Return the current :class:`.Session`, creating it
using the :attr:`.scoped_session.session_factory` if not present.
:param \**kw: Keyword arguments will be passed to the
:attr:`.scoped_session.session_factory` callable, if an existing
:class:`.Session` is not present. If the :class:`.Session` is present
and keyword arguments have been passed,
:exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
"""
if kw:
if self.registry.has():
raise sa_exc.InvalidRequestError(
"Scoped session is already present; "
"no new arguments may be specified."
)
else:
sess = self.session_factory(**kw)
self.registry.set(sess)
return sess
else:
return self.registry()

这个 ThreadLocalRegistry(session_factory) 对象参考如下代码的实现, 发现他同样具有 def __call__(self): 方法,因此,我们上一段里的 self.registry 才能调用这个 call 方法。到这里一切都没有问题(下面代码是这一段的代码展示)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class ThreadLocalRegistry(ScopedRegistry):
"""A :class:`.ScopedRegistry` that uses a ``threading.local()``
variable for storage.
"""
def __init__(self, createfunc):
self.createfunc = createfunc
self.registry = threading.local()
def __call__(self):
try:
return self.registry.value
except AttributeError:
val = self.registry.value = self.createfunc()
return val
def has(self):
return hasattr(self.registry, "value")
def set(self, obj):
self.registry.value = obj
def clear(self):
try:
del self.registry.value
except AttributeError:
pass

当我们开始调用 call 方法的时候,容易让人迷惑的地方出现了,他需要返回这个东西的值 self.registry.value , 但是这个 self.registry 是一个 threading.local() (在最开始的大前提我们已经说过了,这个 dbsession_class 对象已经被初始化过,作为了一个对象保存在了 cache 中,因此在各个阶段,他被初始化的属性都是可以直接拿到的,当然包括这个 threading_local 也是可以直接拿到的属性), 而这个对象存储的值在不同的线程或者协程(如果使用了 gevent , 线程会被协程劫持,就会变成协程)内是不一定一致的,下面的代码会显示这个 threading.local() 对象内数据的存储

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def _patch(self):
key = object.__getattribute__(self, '_local__key')
d = current_thread().__dict__.get(key)
if d is None:
d = {}
current_thread().__dict__[key] = d
object.__setattr__(self, '__dict__', d)
# we have a new instance dict, so call out __init__ if we have
# one
cls = type(self)
if cls.__init__ is not object.__init__:
args, kw = object.__getattribute__(self, '_local__args')
cls.__init__(self, *args, **kw)
else:
object.__setattr__(self, '__dict__', d)
class local(_localbase):
def __getattribute__(self, name):
lock = object.__getattribute__(self, '_local__lock')
lock.acquire()
try:
_patch(self)
return object.__getattribute__(self, name)
finally:
lock.release()
def __setattr__(self, name, value):
if name == '__dict__':
raise AttributeError(
"%r object attribute '__dict__' is read-only"
% self.__class__.__name__)
lock = object.__getattribute__(self, '_local__lock')
lock.acquire()
try:
_patch(self)
return object.__setattr__(self, name, value)
finally:
lock.release()

我们能看到当执行 self.registry.value 时,因为 self.registry 已经是 threading.local() 对象了,因此他的属性的调用应该走的是被复写了的 __getattribute__ 方法,我们从上述代码中看到,他其实是调用了 _patch(self) 函数,这个函数里我们发现每次调用的时候都需要经过一个 current_thread() 的变量,这个变量代表了当前的线程或者协程对象,而不再是我们文章开始时 dbsession_class 这个变量当时的那个线程和协程对象,因此拿不到之前存在那个线程或者协程里存储的变量的值了。因此, 这就解释了为什么我们拿到的是 cache 里的对象,但是还是线程安全的问题。因为,这个对象在使用时如果已经已经进入了一个新的线程或者协程,那么这时从 current_thread() 里获取的属性其实已经是当前线程或者协程的属性了,之前线程或者协程存储的属性不会再这里起作用


  • 再回到我们最开始的那个处理 MySQL 的 session 的代码,这时我们就能很清晰的进行判断了,我们会根据名字提前缓存一个 dbsession 对象,如果这个对象恰巧又在同一个线程或者协程内使用(同一个请求内包含了多次 get_dbsession 方法),那么这时,就相当于直接复用了这个对象。

  • 如果这个对象没有被同一个线程或者协程使用,那么她被放到 cache 中,一旦 get_dbsession 根据同样的名字又拿到了这个对象,从里面取值的时候,由于此时的线程或者协程已经改变了,因此,不会取到之前的那个线程或者携程内存储的信息,做到了线程安全。


threading.local() 是一个全局的变量,可以认为他的形式是如下这样的,这个对象会存储当前正在共存的线程的编号和他对应的值的信息(当然这些值都是赋值上去的),一旦这个线程被销毁,这个字典里对应的 key 和他的值也会被删除,只保留正在共存的键值对。在我们自己的线程或者协程去取对应的值的时候,其实是现根据 current_thread 拿到当前线程的编号或者直接拿到这个线程的对象,然后根据这个线程的线程号,按图索骥,拿到对应的属性的值,因此我们讲在不同的线程中,实例的值会不同。

1
2
3
4
5
{
threading_id1: all_this_threading_attr,
threading_id2: all_this_threading_attr,
...
}