Unverified Commit 3d7db864 by Peter Weidenbach Committed by GitHub

Merge pull request #1 from fkie-cad/bugfix_and_refactoring

unused parameter bugfix + refactoring
parents ec35d19d 7e9a25fa
import logging import logging
from typing import List, Optional
from bson.son import SON from bson.son import SON
from pymongo.command_cursor import CommandCursor
def get_list_of_all_values(collection, object_path, unwind=False, match=None): def get_list_of_all_values(collection, object_path, unwind=False, match=None):
''' '''
Get a list of unique values on a specific object path in a collection. Get a list of unique values on a specific object path in a collection.
An Optional search string (match) can be added. An Optional search string (match) can be added.
If additional_field is set, all values of this field for each
:param collection: mongo collection to look at :param collection: mongo collection to look at
:type collection: pymongo.collection :type collection: pymongo.collection.Collection
:param object_path: mongo object path :param object_path: mongo object path
:type object_path: str :type object_path: str
:param unwind: if true: handle list entries as single values :param unwind: if true: handle list entries as single values
:type unwind: bool :type unwind: bool
:param match: mongo search string :param match: mongo search string
:type match: dict :type match: dict, optional
:return: list :return: list
''' '''
pipeline = [] pipeline = _build_pipeline(object_path, {'_id': object_path}, unwind, SON([('_id', 1)]), match)
if match is not None:
pipeline.append({'$match': match})
pipeline.extend([
{'$group': {'_id': object_path}},
{'$sort': SON([('_id', 1)])}
])
if unwind:
old_pipe = pipeline
pipeline = [{'$unwind': object_path}]
pipeline.extend(old_pipe)
result = _get_list_of_aggregate_list(list(collection.aggregate(pipeline))) result = _get_list_of_aggregate_list(list(collection.aggregate(pipeline)))
logging.debug(result) logging.debug(result)
return result return result
def get_list_of_all_values_and_collect_information_of_additional_field(collection, object_path, additional_information_object_path, unwind=False, match=None): def get_list_of_all_values_and_collect_information_of_additional_field(
collection, object_path, additional_information_object_path, unwind=False, match=None):
''' '''
Get a list of unique values and a collection of additional information on a specific object path in a collection. Get a list of unique values and a collection of additional information on a specific object path in a collection.
An Optional search string (match) can be added. An Optional search string (match) can be added.
:param collection: mongo collection to look at :param collection: mongo collection to look at
:type collection: pymongo.collection :type collection: pymongo.collection.Collection
:param object_path: mongo object path :param object_path: mongo object path
:type object_path: str :type object_path: str
:param additional_information_object_path: field of the additional information :param additional_information_object_path: field of the additional information
...@@ -47,27 +42,41 @@ def get_list_of_all_values_and_collect_information_of_additional_field(collectio ...@@ -47,27 +42,41 @@ def get_list_of_all_values_and_collect_information_of_additional_field(collectio
:param unwind: if true: handle list entries as single values :param unwind: if true: handle list entries as single values
:type unwind: bool :type unwind: bool
:param match: mongo search string :param match: mongo search string
:type match: dict :type match: dict, optional
:return: {<VALUE>:[<ADDITIONAL_INFORMATION_1>, ...], ...} :return: {<VALUE>:[<ADDITIONAL_INFORMATION_1>, ...], ...}
''' '''
pipeline = [] logging.warning('deprecation warning: this method will be removed in a future release')
if match is not None: return get_all_value_combinations_of_fields(collection, object_path, additional_information_object_path, unwind, match)
pipeline.append({'$match': match})
pipeline.extend([
{'$group': {'_id': object_path, 'additional_information': {'$addToSet': '$_id'}}}, def get_all_value_combinations_of_fields(collection, primary_field, secondary_field, unwind=False, match=None):
{'$sort': SON([('_id', 1)])} '''
]) Get a dictionary with all unique values of a field as keys and a list of all unique values that a second field takes
if unwind: on as values (on a specific object path in a collection). An Optional search string (match) can be added.
old_pipe = pipeline
pipeline = [{'$unwind': object_path}] :param collection: mongo collection to look at
pipeline.extend(old_pipe) :type collection: pymongo.collection.Collection
:param primary_field: mongo object path
:type primary_field: str
:param secondary_field: field of the additional information
:type secondary_field: str
:param unwind: if true: handle list entries as single values
:type unwind: bool
:param match: mongo search string
:type match: dict, optional
:return: {<VALUE>:[<ADDITIONAL_INFORMATION_1>, ...], ...}
'''
pipeline = _build_pipeline(
primary_field, {'_id': primary_field, 'additional_information': {'$addToSet': secondary_field}},
unwind, SON([('_id', 1)]), match
)
result = list(collection.aggregate(pipeline)) result = list(collection.aggregate(pipeline))
result = _get_dict_from_aggregat_list(result) result = _get_dict_from_aggregate_list(result)
logging.debug(result) logging.debug(result)
return result return result
def _get_dict_from_aggregat_list(ag_list): def _get_dict_from_aggregate_list(ag_list):
result = {} result = {}
for item in ag_list: for item in ag_list:
result[item['_id']] = item['additional_information'] result[item['_id']] = item['additional_information']
...@@ -75,38 +84,26 @@ def _get_dict_from_aggregat_list(ag_list): ...@@ -75,38 +84,26 @@ def _get_dict_from_aggregat_list(ag_list):
def _get_list_of_aggregate_list(ag_list): def _get_list_of_aggregate_list(ag_list):
result = [] return [item['_id'] for item in ag_list]
for item in ag_list:
result.append(item['_id'])
return result
def get_objects_and_count_of_occurrence(collection, object_path, unwind=False, match=None): def get_objects_and_count_of_occurrence(collection, object_path, unwind=False, match=None):
''' '''
Get a list of unique values and their occurences on a specific object path in a collection. Get a list of unique values and their occurrences on a specific object path in a collection.
An Optional search string (match) can be added. An Optional search string (match) can be added.
:param collection: mongo collection to look at :param collection: mongo collection to look at
:type collection: pymongo.collection :type collection: pymongo.collection.Collection
:param object_path: mongo object path :param object_path: mongo object path
:type object_path: str :type object_path: str
:param unwind: if true: handle list entries as single values :param unwind: if true: handle list entries as single values
:type unwind: bool :type unwind: bool
:param match: mongo search string :param match: mongo search string
:type match: dict :type match: dict, optional
:return: [{'_id': <VALUE>, 'count': <OCCURENCES>}, ...] :return: [{'_id': <VALUE>, 'count': <OCCURRENCES>}, ...]
''' '''
pipeline = [] pipeline = _build_pipeline(object_path, {'_id': object_path, 'count': {'$sum': 1}}, unwind,
if match is not None: SON([('count', -1), ('_id', -1)]), match)
pipeline.append({'$match': match})
pipeline.extend([
{'$group': {'_id': object_path, 'count': {'$sum': 1}}},
{'$sort': SON([('count', -1), ('_id', -1)])}
])
if unwind:
old_pipe = pipeline
pipeline = [{'$unwind': object_path}]
pipeline.extend(old_pipe)
result = list(collection.aggregate(pipeline)) result = list(collection.aggregate(pipeline))
logging.debug(result) logging.debug(result)
return result return result
...@@ -118,7 +115,7 @@ def get_field_sum(collection, object_path, match=None): ...@@ -118,7 +115,7 @@ def get_field_sum(collection, object_path, match=None):
An Optional search string (match) can be added. An Optional search string (match) can be added.
:param collection: mongo collection to look at :param collection: mongo collection to look at
:type collection: pymongo.collection :type collection: pymongo.collection.Collection
:param object_path: mongo object path :param object_path: mongo object path
:type object_path: str :type object_path: str
:param match: mongo search string :param match: mongo search string
...@@ -134,7 +131,7 @@ def get_field_average(collection, object_path, match=None): ...@@ -134,7 +131,7 @@ def get_field_average(collection, object_path, match=None):
An Optional search string (match) can be added. An Optional search string (match) can be added.
:param collection: mongo collection to look at :param collection: mongo collection to look at
:type collection: pymongo.collection :type collection: pymongo.collection.Collection
:param object_path: mongo object path :param object_path: mongo object path
:type object_path: str :type object_path: str
:param match: mongo search string :param match: mongo search string
...@@ -145,12 +142,22 @@ def get_field_average(collection, object_path, match=None): ...@@ -145,12 +142,22 @@ def get_field_average(collection, object_path, match=None):
def get_field_execute_operation(operation, collection, object_path, match=None): def get_field_execute_operation(operation, collection, object_path, match=None):
pipeline = _build_pipeline(object_path, {'_id': 'null', 'total': {operation: object_path}}, match=match)
query_result = collection.aggregate(pipeline)
try:
return query_result.next()['total']
except StopIteration:
return 0
def _build_pipeline(object_path: str, group: dict, unwind: bool = False, sort_key: Optional[SON] = None,
match: Optional[dict] = None) -> List[dict]:
pipeline = [] pipeline = []
if match is not None: if unwind:
pipeline.append({'$unwind': object_path})
if match:
pipeline.append({'$match': match}) pipeline.append({'$match': match})
pipeline.append({'$group': {'_id': 'null', 'total': {operation: object_path}}}) pipeline.append({'$group': group})
tmp = collection.aggregate(pipeline) if sort_key:
result = 0 pipeline.append({'$sort': sort_key})
for item in tmp: return pipeline
result = item['total']
return result
from setuptools import setup, find_packages from setuptools import setup, find_packages
VERSION = '0.3.3' VERSION = '0.4.0'
setup( setup(
name='common_helper_mongo', name='common_helper_mongo',
......
from common_helper_mongo.aggregate import get_objects_and_count_of_occurrence,\ from common_helper_mongo.aggregate import (
get_field_sum, get_field_average, get_list_of_all_values,\ get_all_value_combinations_of_fields, get_field_average, get_field_sum, get_list_of_all_values,
get_list_of_all_values_and_collect_information_of_additional_field get_list_of_all_values_and_collect_information_of_additional_field, get_objects_and_count_of_occurrence,
)
from tests.base_class_database_test import MongoDbTest from tests.base_class_database_test import MongoDbTest
...@@ -25,31 +26,40 @@ class TestAggregate(MongoDbTest): ...@@ -25,31 +26,40 @@ class TestAggregate(MongoDbTest):
self.assertEqual(len(result), 4) self.assertEqual(len(result), 4)
self.assertEqual(result[0], "a") self.assertEqual(result[0], "a")
def test_get_list_of_all_values_and_collect_information_of_additional_field(self): def test_get_all_value_combinations_of_fields(self):
self.test_collection.insert_many([
{"test_list": ["a", "b"], "test_value": 1},
{"test_list": ["c", "d"], "test_value": 2},
{"test_list": ["a", "d"], "test_value": 1}
])
result = get_all_value_combinations_of_fields(self.test_collection, "$test_list", "$test_value", unwind=True)
assert result == {'a': [1], 'b': [1], 'c': [2], 'd': [1, 2]}
def test_get_all_value_combinations_of_fields_id(self):
self.add_list_test_data() self.add_list_test_data()
result = get_list_of_all_values_and_collect_information_of_additional_field(self.test_collection, "$test_list", "$_id", unwind=True, match=None) result = get_all_value_combinations_of_fields(self.test_collection, "$test_list", "$_id", unwind=True)
self.assertIsInstance(result, dict, "result should be a dict") self.assertIsInstance(result, dict, "result should be a dict")
self.assertEqual(len(result.keys()), 4, "number of results not correct") self.assertEqual(len(result.keys()), 4, "number of results not correct")
self.assertEqual(len(result['c']), 2, "c should have two related object ids") self.assertEqual(len(result['c']), 2, "c should have two related object ids")
self.assertEqual(len(result['a']), 1, "a should have one related object id") self.assertEqual(len(result['a']), 1, "a should have one related object id")
def test_get_objects_and_count_of_occurence(self): def test_get_objects_and_count_of_occurrence(self):
self.add_simple_test_data() self.add_simple_test_data()
result = get_objects_and_count_of_occurrence(self.test_collection, "$test_txt", unwind=False, match=None) result = get_objects_and_count_of_occurrence(self.test_collection, "$test_txt", unwind=False, match=None)
self.assertEqual(len(result), 10, "number of results not correct") self.assertEqual(len(result), 10, "number of results not correct")
self.assertEqual(result[0]['_id'], "item 1", "should be the fist element because it has two ocurrences") self.assertEqual(result[0]['_id'], "item 1", "should be the fist element because it has two occurrences")
self.assertEqual(result[0]['count'], 2) self.assertEqual(result[0]['count'], 2)
def test_get_objects_and_count_unwind(self): def test_get_objects_and_count_unwind(self):
self.add_list_test_data() self.add_list_test_data()
result = get_objects_and_count_of_occurrence(self.test_collection, "$test_list", unwind=True, match=None) result = get_objects_and_count_of_occurrence(self.test_collection, "$test_list", unwind=True, match=None)
self.assertEqual(len(result), 4, "number of results not correct") self.assertEqual(len(result), 4, "number of results not correct")
self.assertEqual(result[0]['_id'], "c", "should be the first element because it has two ocurrences") self.assertEqual(result[0]['_id'], "c", "should be the first element because it has two occurrences")
self.assertEqual(result[0]['count'], 2) self.assertEqual(result[0]['count'], 2)
def test_get_objects_and_count_match(self): def test_get_objects_and_count_match(self):
self.add_simple_test_data() self.add_simple_test_data()
result = get_objects_and_count_of_occurrence(self.test_collection, "$test_txt", unwind="False", match={"test_int": 0}) result = get_objects_and_count_of_occurrence(self.test_collection, "$test_txt", unwind=False, match={"test_int": 0})
self.assertEqual(len(result), 1, "number of results not correct") self.assertEqual(len(result), 1, "number of results not correct")
self.assertEqual(result[0]['_id'], "item 0") self.assertEqual(result[0]['_id'], "item 0")
...@@ -67,3 +77,13 @@ class TestAggregate(MongoDbTest): ...@@ -67,3 +77,13 @@ class TestAggregate(MongoDbTest):
self.add_simple_test_data() self.add_simple_test_data()
result = get_field_sum(self.test_collection, "$test_int", match={"test_int": {"$lt": 5}}) result = get_field_sum(self.test_collection, "$test_int", match={"test_int": {"$lt": 5}})
self.assertEqual(result, 10) self.assertEqual(result, 10)
def test_get_field_execute_operation_empty(self):
result = get_field_sum(self.test_collection, "$test_int")
self.assertEqual(result, 0)
def test_get_list_of_all_values_and_collect_information_of_additional_field(self):
self.add_list_test_data()
with self.assertLogs() as logger:
get_list_of_all_values_and_collect_information_of_additional_field(self.test_collection, "$test_list", "$_id", unwind=True)
assert 'deprecation warning' in logger.output.pop()
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment