# Copyright 2009-2015 The TxMongo Developers. All rights reserved.
# Use of this source code is governed by the Apache License that can be
# found in the LICENSE file.
from __future__ import absolute_import, division
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure
from pymongo.uri_parser import parse_uri
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
from twisted.internet import defer, reactor, task
from twisted.internet.protocol import ReconnectingClientFactory, ClientFactory
from twisted.python import log
from twisted.python.compat import StringType
from txmongo.database import Database
from txmongo.protocol import MongoProtocol, Query
from txmongo.utils import timeout
DEFAULT_MAX_BSON_SIZE = 16777216
DEFAULT_MAX_WRITE_BATCH_SIZE = 1000
_PRIMARY_READ_PREFERENCES = {ReadPreference.PRIMARY.mode, ReadPreference.PRIMARY_PREFERRED.mode}
class _Connection(ReconnectingClientFactory):
__notify_ready = None
__allnodes = None
__index = -1
__uri = None
instance = None
protocol = MongoProtocol
def __init__(self, pool, uri, connection_id, initial_delay, max_delay):
self.__allnodes = list(uri["nodelist"])
self.__notify_ready = []
self.__pool = pool
self.__uri = uri
self.connection_id = connection_id
self.initialDelay = initial_delay
self.maxDelay = max_delay
self.__auth_creds = {}
def buildProtocol(self, addr):
# Build the protocol.
p = ReconnectingClientFactory.buildProtocol(self, addr)
self._initializeProto(p)
return p
@defer.inlineCallbacks
def _initializeProto(self, proto):
yield proto.connectionReady()
self.resetDelay()
uri_options = self.uri['options']
slaveok = uri_options.get('slaveok', False)
if 'readpreference' in uri_options:
slaveok = uri_options['readpreference'] not in _PRIMARY_READ_PREFERENCES
try:
if not slaveok:
# Update our server configuration. This may disconnect if the node
# is not a master.
yield self.configure(proto)
yield self._auth_proto(proto)
self.setInstance(instance=proto)
except Exception as e:
proto.fail(e)
@staticmethod
@timeout
def __send_ismaster(proto, **kwargs):
query = Query(collection="admin.$cmd", query={"ismaster": 1})
return proto.send_QUERY(query)
@defer.inlineCallbacks
def configure(self, proto):
"""
Configures the protocol using the information gathered from the
remote Mongo instance. Such information may contain the max
BSON document size, replica set configuration, and the master
status of the instance.
"""
if not proto:
defer.returnValue(None)
reply = yield self.__send_ismaster(proto, timeout=self.initialDelay)
# Handle the reply from the "ismaster" query. The reply contains
# configuration information about the peer.
# Make sure we got a result document.
if len(reply.documents) != 1:
raise OperationFailure("TxMongo: invalid document length.")
# Get the configuration document from the reply.
config = reply.documents[0].decode()
# Make sure the command was successful.
if not config.get("ok"):
code = config.get("code")
msg = "TxMongo: " + config.get("err", "Unknown error")
raise OperationFailure(msg, code)
# Check that the replicaSet matches.
set_name = config.get("setName")
expected_set_name = self.uri["options"].get("replicaset")
if expected_set_name and (expected_set_name != set_name):
# Log the invalid replica set failure.
msg = "TxMongo: Mongo instance does not match requested replicaSet."
raise ConfigurationError(msg)
# Track max bson object size limit.
proto.max_bson_size = config.get("maxBsonObjectSize", DEFAULT_MAX_BSON_SIZE)
proto.max_write_batch_size = config.get("maxWriteBatchSize", DEFAULT_MAX_WRITE_BATCH_SIZE)
proto.set_wire_versions(config.get("minWireVersion", 0),
config.get("maxWireVersion", 0))
# Track the other hosts in the replica set.
hosts = config.get("hosts")
if isinstance(hosts, list) and hosts:
for host in hosts:
if ':' not in host:
host = (host, 27017)
else:
host = host.split(':', 1)
host[1] = int(host[1])
host = tuple(host)
if host not in self.__allnodes:
self.__allnodes.append(host)
# Check if this node is the master.
ismaster = config.get("ismaster")
if not ismaster:
msg = "TxMongo: MongoDB host `%s` is not master." % config.get('me')
raise AutoReconnect(msg)
def clientConnectionFailed(self, connector, reason):
self.instance = None
if self.continueTrying:
self.connector = connector
self.retryNextHost()
def clientConnectionLost(self, connector, reason):
self.instance = None
if self.continueTrying:
self.connector = connector
self.retryNextHost()
def notifyReady(self):
"""
Returns a deferred that will fire when the factory has created a
protocol that can be used to communicate with a Mongo server.
Note that this will not fire until we have connected to a Mongo
master, unless slaveOk was specified in the Mongo URI connection
options.
"""
if self.instance:
return defer.succeed(self.instance)
def on_cancel(d):
self.__notify_ready.remove(d)
df = defer.Deferred(on_cancel)
self.__notify_ready.append(df)
return df
def retryNextHost(self, connector=None):
"""
Have this connector connect again, to the next host in the
configured list of hosts.
"""
if not self.continueTrying:
msg = "TxMongo: Abandoning {0} on explicit request.".format(connector)
log.msg(msg)
return
if connector is None:
if self.connector is None:
raise ValueError("TxMongo: No additional connector to retry.")
else:
connector = self.connector
delay = False
self.__index += 1
if self.__index >= len(self.__allnodes):
self.__index = 0
delay = True
connector.host, connector.port = self.__allnodes[self.__index]
if delay:
self.retry(connector)
else:
connector.connect()
def setInstance(self, instance=None, reason=None):
if instance == self.instance:
# Should not fail deferreds from __notify_ready if setInstance(None)
# called when instance is already None
return
self.instance = instance
deferreds, self.__notify_ready = self.__notify_ready, []
if deferreds:
for df in deferreds:
if instance:
df.callback(self)
else:
df.errback(reason)
@property
def uri(self):
return self.__uri
def _auth_proto(self, proto):
return defer.DeferredList(
[proto.authenticate(database, username, password, mechanism)
for database, (username, password, mechanism) in self.__auth_creds.items()],
consumeErrors=True
)
def authenticate(self, database, username, password, mechanism):
self.__auth_creds[str(database)] = (username, password, mechanism)
if self.instance:
return self.instance.authenticate(database, username, password, mechanism)
else:
return defer.succeed(None)
[docs]class ConnectionPool(object):
__index = 0
__pool = None
__pool_size = None
__uri = None
__wc_possible_options = {'w', "wtimeout", 'j', "fsync"}
__pinger_discovery_interval = 10
def __init__(self, uri="mongodb://127.0.0.1:27017", pool_size=1, ssl_context_factory=None,
ping_interval=10, ping_timeout=10, **kwargs):
assert isinstance(uri, StringType)
assert isinstance(pool_size, int)
assert pool_size >= 1
if not uri.startswith("mongodb://"):
uri = "mongodb://" + uri
self.__uri = parse_uri(uri)
wc_options = self.__uri['options'].copy()
wc_options.update(kwargs)
wc_options = dict((k, v) for k, v in wc_options.items() if k in self.__wc_possible_options)
self.__write_concern = WriteConcern(**wc_options)
self.__codec_options = kwargs.get('codec_options', DEFAULT_CODEC_OPTIONS)
retry_delay = kwargs.get('retry_delay', 1.0)
max_delay = kwargs.get('max_delay', 60.0)
self.__pool_size = pool_size
self.__pool = [
_Connection(self, self.__uri, i, retry_delay, max_delay)
for i in range(pool_size)
]
if self.__uri['database'] and self.__uri['username'] and self.__uri['password']:
auth_db = self.__uri['options'].get('authsource') or self.__uri['database']
self.authenticate(auth_db, self.__uri['username'],
self.__uri['password'],
self.__uri['options'].get('authmechanism', 'DEFAULT'))
host, port = self.__uri['nodelist'][0]
self.ssl_context_factory = ssl_context_factory
initial_delay = kwargs.get('retry_delay', 30)
for factory in self.__pool:
factory.connector = self.__tcp_or_ssl_connect(host, port, factory,
timeout=initial_delay)
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.__pingers = {}
self.__pinger_discovery = task.LoopingCall(self.__discovery_nodes_to_ping)
self.__pinger_discovery.start(self.__pinger_discovery_interval, now=False)
def __tcp_or_ssl_connect(self, host, port, factory, **kwargs):
if self.ssl_context_factory:
return reactor.connectSSL(host, port, factory, self.ssl_context_factory, **kwargs)
else:
return reactor.connectTCP(host, port, factory, **kwargs)
@property
def write_concern(self):
return self.__write_concern
@property
def codec_options(self):
return self.__codec_options
[docs] def getprotocols(self):
return self.__pool
def __getitem__(self, name):
return Database(self, name)
def __getattr__(self, name):
return self[name]
def __repr__(self):
if self.uri["nodelist"]:
return "Connection(%r, %r)" % self.uri["nodelist"][0]
return "Connection()"
[docs] def get_default_database(self):
if self.uri["database"]:
return self[self.uri["database"]]
else:
return None
[docs] def drop_database(self, name_or_database):
if isinstance(name_or_database, (bytes, StringType)):
db = self[name_or_database]
elif isinstance(name_or_database, Database):
db = name_or_database
else:
raise TypeError("argument to drop_database() should be database name "
"or database object")
return db.command("dropDatabase")
[docs] def disconnect(self):
self.__pinger_discovery.stop()
for pinger in self.__pingers.values():
pinger.connector.disconnect()
for factory in self.__pool:
factory.stopTrying()
factory.stopFactory()
if factory.instance and factory.instance.transport:
factory.instance.transport.loseConnection()
if factory.connector:
factory.connector.disconnect()
# Wait for the next iteration of the loop for resolvers
# to potentially cleanup.
df = defer.Deferred()
reactor.callLater(0, df.callback, None)
return df
[docs] def authenticate(self, database, username, password, mechanism="DEFAULT"):
def fail(failure):
failure.trap(defer.FirstError)
raise failure.value.subFailure.value
return defer.gatherResults(
[connection.authenticate(database, username, password, mechanism)
for connection in self.__pool],
consumeErrors=True
).addErrback(fail)
[docs] def getprotocol(self):
# Get the next protocol available for communication in the pool.
connection = self.__pool[self.__index]
self.__index = (self.__index + 1) % self.__pool_size
# If the connection is already connected, just return it.
if connection.instance:
return defer.succeed(connection.instance)
# Wait for the connection to connection.
return connection.notifyReady().addCallback(lambda conn: conn.instance)
@property
def uri(self):
return self.__uri
# Pingers are persistent connections that are established to each
# node of the replicaset to monitor their availability.
#
# Every `__pinger_discovery_interval` seconds ConnectionPool compares
# actual nodes addresses and starts/stops Pingers to ensure that
# Pinger is started for every node address.
#
# Every `ping_interval` seconds pingers send ismaster commands.
#
# All pool connections to corresponding TCP address are dropped
# if one of following happens:
# 1. Pinger is unable to receive response to ismaster within
# `ping_timeout` seconds
# 2. Pinger is unable to connect to address within `ping_timeout`
# seconds
#
# If Pinger's connection is closed by server, pool connections are not
# dropped. Next discovery procedure will recreate the Pinger.
def __discovery_nodes_to_ping(self):
existing = set(self.__pingers)
peers = {conn.instance.transport.getPeer()
for conn in self.__pool if conn.instance}
for peer in peers - existing:
pinger = _Pinger(self.ping_interval, self.ping_timeout,
self.__on_ping_lost, self.__on_ping_fail)
pinger.connector = self.__tcp_or_ssl_connect(peer.host, peer.port, pinger,
timeout=self.ping_timeout)
self.__pingers[peer] = pinger
for unused_peer in existing - peers:
self.__pingers[unused_peer].connector.disconnect()
del self.__pingers[unused_peer]
def __on_ping_lost(self, addr):
if addr in self.__pingers:
self.__pingers[addr].connector.disconnect()
del self.__pingers[addr]
def __on_ping_fail(self, addr):
# Kill all pool connections to this addr
for connection in self.__pool:
if connection.instance and connection.instance.transport.getPeer() == addr:
connection.instance.transport.abortConnection()
self.__on_ping_lost(addr)
class _PingerProtocol(MongoProtocol):
__next_call = None
def __init__(self, interval, timeout, fail_callback):
MongoProtocol.__init__(self)
self.interval = interval
self.timeout = timeout
self.fail_callback = fail_callback
def ping(self):
def on_ok(result):
if timeout_call.active():
timeout_call.cancel()
self.__next_call = reactor.callLater(self.interval, self.ping)
def on_fail(failure):
if timeout_call.active():
timeout_call.cancel()
on_timeout()
def on_timeout():
self.transport.loseConnection()
self.fail_callback(self.transport.getPeer())
timeout_call = reactor.callLater(self.timeout, on_timeout)
self.send_QUERY(Query(collection="admin.$cmd", query={"ismaster": 1}))\
.addCallbacks(on_ok, on_fail)
def connectionMade(self):
MongoProtocol.connectionMade(self)
self.ping()
def connectionLost(self, reason):
MongoProtocol.connectionLost(self, reason)
if self.__next_call and self.__next_call.active():
self.__next_call.cancel()
class _Pinger(ClientFactory):
def __init__(self, interval, timeout, lost_callback, fail_callback):
self.interval = interval
self.timeout = timeout
self.lost_callback = lost_callback
self.fail_callback = fail_callback
def buildProtocol(self, addr):
proto = _PingerProtocol(self.interval, self.timeout, self.fail_callback)
proto.factory = self
return proto
def setInstance(self, instance=None, reason=None):
pass
def clientConnectionLost(self, connector, reason):
self.lost_callback(connector.getDestination())
def clientConnectionFailed(self, connector, reason):
self.fail_callback(connector.getDestination())
###
# Begin Legacy Wrapper
###
[docs]class MongoConnection(ConnectionPool):
def __init__(self, host="127.0.0.1", port=27017, pool_size=1, **kwargs):
uri = "mongodb://%s:%d/" % (host, port)
ConnectionPool.__init__(self, uri, pool_size=pool_size, **kwargs)
lazyMongoConnectionPool = MongoConnection
lazyMongoConnection = MongoConnection
MongoConnectionPool = MongoConnection
###
# End Legacy Wrapper
###