From d0a92797d0346b192e504afebeb0dec483105a4c Mon Sep 17 00:00:00 2001 From: Peter Weidenbach <weidenba@cs.uni-bonn.de> Date: Tue, 13 Dec 2016 13:20:41 +0100 Subject: [PATCH] aggregation feature added --- common_helper_mongo/__init__.py | 4 +++- common_helper_mongo/aggregate.py | 19 +++++++++++++++++++ tests/base_class_database_test.py | 10 ++++++++++ tests/test_aggregate.py | 24 ++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 common_helper_mongo/aggregate.py create mode 100644 tests/test_aggregate.py diff --git a/common_helper_mongo/__init__.py b/common_helper_mongo/__init__.py index ddedbfb..b667068 100644 --- a/common_helper_mongo/__init__.py +++ b/common_helper_mongo/__init__.py @@ -1,5 +1,7 @@ from .gridfs import overwrite_file +from .aggregate import get_objects_and_count_of_occurrence __all__ = [ - 'overwrite_file' + 'overwrite_file', + 'get_objects_and_count_of_occurrence' ] diff --git a/common_helper_mongo/aggregate.py b/common_helper_mongo/aggregate.py new file mode 100644 index 0000000..6b987a4 --- /dev/null +++ b/common_helper_mongo/aggregate.py @@ -0,0 +1,19 @@ +import logging +from bson.son import SON + + +def get_objects_and_count_of_occurrence(collection, object_path, unwind=False, match=None): + pipeline = [] + if match is not None: + pipeline.append({"$match": match}) + pipeline.extend([ + {"$group": {"_id": object_path, "count": {"$sum": 1}}}, + {"$sort": SON([("count", -1), ("_id", -1)])} + ]) + if unwind: + old_pipe = pipeline + pipeline = [{"$unwind": object_path}] + pipeline.extend(old_pipe) + result = list(collection.aggregate(pipeline)) + logging.debug(result) + return result diff --git a/tests/base_class_database_test.py b/tests/base_class_database_test.py index 29edad4..52cd31a 100644 --- a/tests/base_class_database_test.py +++ b/tests/base_class_database_test.py @@ -7,10 +7,20 @@ class MongoDbTest(unittest.TestCase): def setUp(self): self.mongo_client = MongoClient() self.db = self.mongo_client["common_code_test"] + self.test_collection = self.db.test_data def tearDown(self): self.mongo_client.drop_database(self.db) self.mongo_client.close() + def add_simple_test_data(self): + for i in range(10): + self.test_collection.insert_one({"test_int": i, "test_txt": "item {}".format(i)}) + self.test_collection.insert_one({"test_txt": "item 1"}) + + def add_list_test_data(self): + self.test_collection.insert_one({"test_list": ["a", "b", "c"]}) + self.test_collection.insert_one({"test_list": ["c", "d"]}) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py new file mode 100644 index 0000000..92be4b0 --- /dev/null +++ b/tests/test_aggregate.py @@ -0,0 +1,24 @@ +from common_helper_mongo.aggregate import get_objects_and_count_of_occurrence +import unittest +from tests.base_class_database_test import MongoDbTest + + +class TestAggregate(MongoDbTest): + + def test_get_objects_and_count_of_occurence(self): + self.add_simple_test_data() + result = get_objects_and_count_of_occurrence(self.test_collection, "$test_txt", unwind=False, match=None) + self.assertEqual(len(result), 10, "number of results not correct") + self.assertEqual(result[0]['_id'], "item 1", "should be the fist element because it has two ocurrences") + self.assertEqual(result[0]['count'], 2) + + def test_get_objects_and_count_unwind(self): + self.add_list_test_data() + result = get_objects_and_count_of_occurrence(self.test_collection, "$test_list", unwind=True, match=None) + self.assertEqual(len(result), 4, "number of results not correct") + self.assertEqual(result[0]['_id'], "c", "should be the first element because it has two ocurrences") + self.assertEqual(result[0]['count'], 2) + + +if __name__ == "__main__": + unittest.main() -- libgit2 0.26.0