# Copyright 2009-2015 The TxMongo Developers. All rights reserved.
# Use of this source code is governed by the Apache License that can be
# found in the LICENSE file.
from __future__ import absolute_import, division
import bson
from bson import BSON, ObjectId
from bson.code import Code
from bson.son import SON
from pymongo.errors import InvalidName
from pymongo.helpers import _check_write_command_response
from pymongo.results import InsertOneResult, InsertManyResult, UpdateResult, \
DeleteResult
from pymongo.common import validate_ok_for_update, validate_ok_for_replace, \
validate_is_mapping, validate_boolean
from pymongo.collection import ReturnDocument
from pymongo.write_concern import WriteConcern
from txmongo.protocol import DELETE_SINGLE_REMOVE, UPDATE_UPSERT, UPDATE_MULTI, \
Query, Getmore, Insert, Update, Delete, KillCursors, INSERT_CONTINUE_ON_ERROR
from txmongo import filter as qf
from twisted.internet import defer
from twisted.python.compat import unicode, comparable
@comparable
[docs]class Collection(object):
def __init__(self, database, name, write_concern=None):
if not isinstance(name, (bytes, unicode)):
raise TypeError("name must be an instance of (bytes, unicode)")
if not name or ".." in name:
raise InvalidName("collection names cannot be empty")
if "$" in name and not (name.startswith("oplog.$main") or
name.startswith("$cmd")):
raise InvalidName("collection names must not contain '$': %r" % name)
if name[0] == "." or name[-1] == ".":
raise InvalidName("collection names must not start or end with '.': %r" % name)
if "\x00" in name:
raise InvalidName("collection names must not contain the null character")
self._database = database
self._collection_name = unicode(name)
self.__write_concern = write_concern
def __str__(self):
return "%s.%s" % (str(self._database), self._collection_name)
def __repr__(self):
return "Collection(%s, %s)" % (self._database, self._collection_name)
@property
def full_name(self):
return '{0}.{1}'.format(str(self._database), self._collection_name)
@property
def name(self):
return self._collection_name
@property
def database(self):
return self._database
def __getitem__(self, collection_name):
return Collection(self._database,
"%s.%s" % (self._collection_name, collection_name))
def __cmp__(self, other):
if isinstance(other, Collection):
def cmp(a, b):
return (a > b) - (a < b)
return cmp((self._database, self._collection_name),
(other._database, other._collection_name))
return NotImplemented
def __getattr__(self, collection_name):
return self[collection_name]
def __call__(self, collection_name):
return self[collection_name]
@property
def write_concern(self):
return self.__write_concern or self._database.write_concern
[docs] def with_options(self, **kwargs):
"""Get a clone of collection changing the specified settings."""
# PyMongo's method gets several positional arguments. We support
# only write_concern for now which is the 3rd positional argument.
# So we are using **kwargs here to force user's code to specify
# write_concern as named argument, so adding other args in future
# won't break compatibility
write_concern = kwargs.get("write_concern") or self.__write_concern
return Collection(self._database, self._collection_name,
write_concern=write_concern)
@staticmethod
def _normalize_fields_projection(fields):
"""
transform a list of fields from ["a", "b"] to {"a":1, "b":1}
"""
if fields is None:
return None
if isinstance(fields, dict):
return fields
# Consider fields as iterable
as_dict = {}
for field in fields:
if not isinstance(field, (bytes, unicode)):
raise TypeError("fields must be a list of key names")
as_dict[field] = 1
if not as_dict:
# Empty list should be treated as "_id only"
as_dict = {"_id": 1}
return as_dict
@staticmethod
def _gen_index_name(keys):
return u'_'.join([u"%s_%s" % item for item in keys])
@defer.inlineCallbacks
[docs] def options(self):
result = yield self._database.system.namespaces.find_one({"name": str(self)})
if not result:
result = {}
options = result.get("options", {})
if "create" in options:
del options["create"]
defer.returnValue(options)
@defer.inlineCallbacks
[docs] def find(self, spec=None, skip=0, limit=0, fields=None, filter=None, cursor=False, **kwargs):
docs, dfr = yield self.find_with_cursor(spec=spec, skip=skip, limit=limit,
fields=fields, filter=filter, **kwargs)
if cursor:
defer.returnValue((docs, dfr))
result = []
while docs:
result.extend(docs)
docs, dfr = yield dfr
defer.returnValue(result)
def __apply_find_filter(self, spec, filter):
if filter:
if "query" not in spec:
spec = {"$query": spec}
for k, v in filter.items():
if isinstance(v, (list, tuple)):
spec['$' + k] = dict(v)
else:
spec['$' + k] = v
return spec
[docs] def find_with_cursor(self, spec=None, skip=0, limit=0, fields=None, filter=None, **kwargs):
""" find method that uses the cursor to only return a block of
results at a time.
Arguments are the same as with find()
returns deferred that results in a tuple: (docs, deferred) where
docs are the current page of results and deferred results in the next
tuple. When the cursor is exhausted, it will return the tuple
([], None)
"""
if spec is None:
spec = SON()
if not isinstance(spec, dict):
raise TypeError("spec must be an instance of dict")
if not isinstance(fields, (dict, list)) and fields is not None:
raise TypeError("fields must be an instance of dict or list")
if not isinstance(skip, int):
raise TypeError("skip must be an instance of int")
if not isinstance(limit, int):
raise TypeError("limit must be an instance of int")
fields = self._normalize_fields_projection(fields)
spec = self.__apply_find_filter(spec, filter)
as_class = kwargs.get("as_class", dict)
deferred_protocol = self._database.connection.getprotocol()
def after_connection(proto):
flags = kwargs.get("flags", 0)
query = Query(flags=flags, collection=str(self),
n_to_skip=skip, n_to_return=limit,
query=spec, fields=fields)
deferred_query = proto.send_QUERY(query)
deferred_query.addCallback(after_reply, proto)
return deferred_query
def after_reply(reply, proto, fetched=0):
documents = reply.documents
docs_count = len(documents)
if limit > 0:
docs_count = min(docs_count, limit - fetched)
fetched += docs_count
options = bson.codec_options.CodecOptions(document_class=as_class)
out = [document.decode(codec_options=options) for document in documents[:docs_count]]
if reply.cursor_id:
if limit == 0:
to_fetch = 0 # no limit
elif limit < 0:
# We won't actually get here because MongoDB won't
# create cursor when limit < 0
to_fetch = None
else:
to_fetch = limit - fetched
if to_fetch <= 0:
to_fetch = None # close cursor
if to_fetch is None:
proto.send_KILL_CURSORS(KillCursors(cursors=[reply.cursor_id]))
return out, defer.succeed(([], None))
next_reply = proto.send_GETMORE(Getmore(
collection=str(self), cursor_id=reply.cursor_id,
n_to_return=to_fetch
))
next_reply.addCallback(after_reply, proto, fetched)
return out, next_reply
return out, defer.succeed(([], None))
deferred_protocol.addCallback(after_connection)
return deferred_protocol
@defer.inlineCallbacks
[docs] def find_one(self, spec=None, fields=None, **kwargs):
if isinstance(spec, ObjectId):
spec = {"_id": spec}
result = yield self.find(spec=spec, limit=1, fields=fields, **kwargs)
defer.returnValue(result[0] if result else None)
@defer.inlineCallbacks
[docs] def count(self, spec=None, fields=None):
fields = self._normalize_fields_projection(fields)
result = yield self._database.command("count", self._collection_name,
query=spec or SON(),
fields=fields)
defer.returnValue(result["n"])
[docs] def group(self, keys, initial, reduce, condition=None, finalize=None):
body = {
"ns": self._collection_name,
"initial": initial,
"$reduce": Code(reduce),
}
if isinstance(keys, (bytes, unicode)):
body["$keyf"] = Code(keys)
else:
body["key"] = self._normalize_fields_projection(keys)
if condition:
body["cond"] = condition
if finalize:
body["finalize"] = Code(finalize)
return self._database.command("group", body)
@defer.inlineCallbacks
[docs] def filemd5(self, spec):
if not isinstance(spec, ObjectId):
raise ValueError("filemd5 expected an objectid for its "
"non-keyword argument")
result = yield self._database.command("filemd5", spec, root=self._collection_name)
defer.returnValue(result.get("md5"))
def _get_write_concern(self, safe=None, **wc_options):
from_opts = WriteConcern(**wc_options)
if from_opts.document:
return from_opts
if safe == True:
if self.write_concern.acknowledged:
return self.write_concern
else:
# Edge case: MongoConnection(w=0).db.coll.insert(..., safe=True)
# In this case safe=True must issue getLastError without args
# even if connection-level write concern was unacknowledged
return WriteConcern()
elif safe == False:
return WriteConcern(w=0)
return self.write_concern
@defer.inlineCallbacks
[docs] def insert(self, docs, safe=None, flags=0, **kwargs):
if isinstance(docs, dict):
ids = docs.get("_id", ObjectId())
docs["_id"] = ids
docs = [docs]
elif isinstance(docs, list):
ids = []
for doc in docs:
if isinstance(doc, dict):
oid = doc.get("_id", ObjectId())
ids.append(oid)
doc["_id"] = oid
else:
raise TypeError("insert takes a document or a list of documents")
else:
raise TypeError("insert takes a document or a list of documents")
docs = [BSON.encode(d) for d in docs]
insert = Insert(flags=flags, collection=str(self), documents=docs)
proto = yield self._database.connection.getprotocol()
proto.send_INSERT(insert)
write_concern = self._get_write_concern(safe, **kwargs)
if write_concern.acknowledged:
yield proto.get_last_error(str(self._database), **write_concern.document)
defer.returnValue(ids)
@defer.inlineCallbacks
def _insert_one_or_many(self, documents, ordered=True):
if self.write_concern.acknowledged:
inserted_ids = []
for doc in documents:
if "_id" not in doc:
doc["_id"] = ObjectId()
inserted_ids.append(doc["_id"])
command = SON([("insert", self._collection_name),
("documents", documents),
("ordered", ordered),
("writeConcern", self.write_concern.document)])
response = yield self._database.command(command)
_check_write_command_response([[0, response]])
else:
# falling back to OP_INSERT in case of unacknowledged op
flags = INSERT_CONTINUE_ON_ERROR if not ordered else 0
inserted_ids = yield self.insert(documents, flags=flags)
defer.returnValue(inserted_ids)
@defer.inlineCallbacks
[docs] def insert_one(self, document):
inserted_ids = yield self._insert_one_or_many([document])
defer.returnValue(InsertOneResult(inserted_ids[0], self.write_concern.acknowledged))
@defer.inlineCallbacks
[docs] def insert_many(self, documents, ordered=True):
inserted_ids = yield self._insert_one_or_many(documents, ordered)
defer.returnValue(InsertManyResult(inserted_ids, self.write_concern.acknowledged))
@defer.inlineCallbacks
[docs] def update(self, spec, document, upsert=False, multi=False, safe=None, flags=0, **kwargs):
if not isinstance(spec, dict):
raise TypeError("spec must be an instance of dict")
if not isinstance(document, dict):
raise TypeError("document must be an instance of dict")
if not isinstance(upsert, bool):
raise TypeError("upsert must be an instance of bool")
if multi:
flags |= UPDATE_MULTI
if upsert:
flags |= UPDATE_UPSERT
spec = BSON.encode(spec)
document = BSON.encode(document)
update = Update(flags=flags, collection=str(self),
selector=spec, update=document)
proto = yield self._database.connection.getprotocol()
proto.send_UPDATE(update)
write_concern = self._get_write_concern(safe, **kwargs)
if write_concern.acknowledged:
ret = yield proto.get_last_error(str(self._database), **write_concern.document)
defer.returnValue(ret)
@defer.inlineCallbacks
def _update(self, filter, update, upsert, multi):
validate_is_mapping("filter", filter)
validate_boolean("upsert", upsert)
if self.write_concern.acknowledged:
updates = [SON([('q', filter), ('u', update),
("upsert", upsert), ("multi", multi)])]
command = SON([("update", self._collection_name),
("updates", updates),
("writeConcern", self.write_concern.document)])
raw_response = yield self._database.command(command)
_check_write_command_response([[0, raw_response]])
# Extract upserted_id from returned array
if raw_response.get("upserted"):
raw_response["upserted"] = raw_response["upserted"][0]["_id"]
else:
yield self.update(filter, update, upsert=upsert, multi=multi)
raw_response = None
defer.returnValue(raw_response)
@defer.inlineCallbacks
[docs] def update_one(self, filter, update, upsert=False):
validate_ok_for_update(update)
raw_response = yield self._update(filter, update, upsert, multi=False)
defer.returnValue(UpdateResult(raw_response, self.write_concern.acknowledged))
@defer.inlineCallbacks
[docs] def update_many(self, filter, update, upsert=False):
validate_ok_for_update(update)
raw_response = yield self._update(filter, update, upsert, multi=True)
defer.returnValue(UpdateResult(raw_response, self.write_concern.acknowledged))
@defer.inlineCallbacks
[docs] def replace_one(self, filter, replacement, upsert=False):
validate_ok_for_replace(replacement)
raw_response = yield self._update(filter, replacement, upsert, multi=False)
defer.returnValue(UpdateResult(raw_response, self.write_concern.acknowledged))
[docs] def save(self, doc, safe=None, **kwargs):
if not isinstance(doc, dict):
raise TypeError("cannot save objects of type %s" % type(doc))
oid = doc.get("_id")
if oid:
return self.update({"_id": oid}, doc, safe=safe, upsert=True, **kwargs)
else:
return self.insert(doc, safe=safe, **kwargs)
@defer.inlineCallbacks
[docs] def remove(self, spec, safe=None, single=False, flags=0, **kwargs):
if isinstance(spec, ObjectId):
spec = SON(dict(_id=spec))
if not isinstance(spec, dict):
raise TypeError("spec must be an instance of dict, not %s" % type(spec))
if single:
flags |= DELETE_SINGLE_REMOVE
spec = BSON.encode(spec)
delete = Delete(flags=flags, collection=str(self), selector=spec)
proto = yield self._database.connection.getprotocol()
proto.send_DELETE(delete)
write_concern = self._get_write_concern(safe, **kwargs)
if write_concern.acknowledged:
ret = yield proto.get_last_error(str(self._database), **write_concern.document)
defer.returnValue(ret)
@defer.inlineCallbacks
def _delete(self, filter, multi):
validate_is_mapping("filter", filter)
if self.write_concern.acknowledged:
deletes = [SON([('q', filter), ("limit", 0 if multi else 1)])]
command = SON([("delete", self._collection_name),
("deletes", deletes),
("writeConcern", self.write_concern.document)])
raw_response = yield self._database.command(command)
_check_write_command_response([[0, raw_response]])
else:
yield self.remove(filter, single=not multi)
raw_response = None
defer.returnValue(raw_response)
@defer.inlineCallbacks
[docs] def delete_one(self, filter):
raw_response = yield self._delete(filter, multi=False)
defer.returnValue(DeleteResult(raw_response, self.write_concern.acknowledged))
@defer.inlineCallbacks
[docs] def delete_many(self, filter):
raw_response = yield self._delete(filter, multi=True)
defer.returnValue(DeleteResult(raw_response, self.write_concern.acknowledged))
[docs] def drop(self, **kwargs):
return self._database.drop_collection(self._collection_name)
@defer.inlineCallbacks
[docs] def create_index(self, sort_fields, **kwargs):
if not isinstance(sort_fields, qf.sort):
raise TypeError("sort_fields must be an instance of filter.sort")
if "name" not in kwargs:
name = self._gen_index_name(sort_fields["orderby"])
else:
name = kwargs.pop("name")
key = SON()
for k, v in sort_fields["orderby"]:
key.update({k: v})
index = SON(dict(
ns=str(self),
name=name,
key=key
))
if "drop_dups" in kwargs:
kwargs["dropDups"] = kwargs.pop("drop_dups")
if "bucket_size" in kwargs:
kwargs["bucketSize"] = kwargs.pop("bucket_size")
index.update(kwargs)
yield self._database.system.indexes.insert(index, safe=True)
defer.returnValue(name)
[docs] def ensure_index(self, sort_fields, **kwargs):
# ensure_index is an alias of create_index since we are not
# keeping an index cache same way pymongo does
return self.create_index(sort_fields, **kwargs)
[docs] def drop_index(self, index_identifier):
if isinstance(index_identifier, (bytes, unicode)):
name = index_identifier
elif isinstance(index_identifier, qf.sort):
name = self._gen_index_name(index_identifier["orderby"])
else:
raise TypeError("index_identifier must be a name or instance of filter.sort")
return self._database.command("deleteIndexes", self._collection_name,
index=name,
allowable_errors=["ns not found"])
[docs] def drop_indexes(self):
return self.drop_index("*")
@defer.inlineCallbacks
[docs] def rename(self, new_name):
to = "%s.%s" % (str(self._database), new_name)
return self._database("admin").command("renameCollection", str(self), to=to)
@defer.inlineCallbacks
[docs] def distinct(self, key, spec=None):
params = {"key": key}
if spec:
params["query"] = spec
result = yield self._database.command("distinct", self._collection_name, **params)
defer.returnValue(result.get("values"))
@defer.inlineCallbacks
[docs] def aggregate(self, pipeline, full_response=False):
raw = yield self._database.command("aggregate", self._collection_name, pipeline=pipeline)
if full_response:
defer.returnValue(raw)
defer.returnValue(raw.get("result"))
@defer.inlineCallbacks
[docs] def map_reduce(self, map, reduce, full_response=False, **kwargs):
params = {"map": map, "reduce": reduce}
params.update(**kwargs)
raw = yield self._database.command("mapreduce", self._collection_name, **params)
if full_response:
defer.returnValue(raw)
defer.returnValue(raw.get("results"))
@defer.inlineCallbacks
[docs] def find_and_modify(self, query=None, update=None, upsert=False, **kwargs):
no_obj_error = "No matching object found"
if not update and not kwargs.get("remove", None):
raise ValueError("Must either update or remove")
if update and kwargs.get("remove", None):
raise ValueError("Can't do both update and remove")
params = kwargs
# No need to include empty args
if query:
params["query"] = query
if update:
params["update"] = update
if upsert:
params["upsert"] = upsert
result = yield self._database.command("findAndModify", self._collection_name,
allowable_errors=[no_obj_error],
**params)
if not result["ok"]:
if result["errmsg"] == no_obj_error:
defer.returnValue(None)
else:
# Should never get here because of allowable_errors
raise ValueError("Unexpected Error: %s" % (result,))
defer.returnValue(result.get("value"))
# Distinct findAndModify utility method is needed because traditional
# find_and_modify() accepts `sort` kwarg as dict and passes it to
# MongoDB command without conversion. But in find_one_and_*
# methods we want to take `filter.sort` instances
@defer.inlineCallbacks
def _new_find_and_modify(self, filter, projection, sort, upsert=None,
return_document=ReturnDocument.BEFORE, **kwargs):
validate_is_mapping("filter", filter)
if not isinstance(return_document, bool):
raise ValueError("return_document must be ReturnDocument.BEFORE "
"or ReturnDocument.AFTER")
cmd = SON([("findAndModify", self._collection_name),
("query", filter),
("new", return_document)])
cmd.update(kwargs)
if projection is not None:
cmd["fields"] = self._normalize_fields_projection(projection)
if sort is not None:
cmd["sort"] = dict(sort["orderby"])
if upsert is not None:
validate_boolean("upsert", upsert)
cmd["upsert"] = upsert
no_obj_error = "No matching object found"
result = yield self._database.command(cmd, allowable_errors=[no_obj_error])
defer.returnValue(result.get("value"))
[docs] def find_one_and_delete(self, filter, projection=None, sort=None, **kwargs):
return self._new_find_and_modify(filter, projection, sort, remove=True, **kwargs)
[docs] def find_one_and_replace(self, filter, replacement, projection=None, sort=None,
upsert=False, return_document=ReturnDocument.BEFORE, **kwargs):
validate_ok_for_replace(replacement)
return self._new_find_and_modify(filter, projection, sort, upsert, return_document,
update=replacement, **kwargs)
[docs] def find_one_and_update(self, filter, update, projection=None, sort=None,
upsert=False, return_document=ReturnDocument.BEFORE, **kwargs):
validate_ok_for_update(update)
return self._new_find_and_modify(filter, projection, sort, upsert, return_document,
update=update, **kwargs)