# Source code for save_to_db.adapters.django_adapter

from django.apps import apps

from django.db import models, transaction
from django.db.models import Q

from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.fields.related import ManyToManyField, ManyToManyRel
from django.db.models.fields.related import ManyToOneRel
from django.db.models.fields.related import ForeignKey

# from django.db.models.fields.related_descriptors import \
#     ReverseOneToOneDescriptor, ReverseManyToOneDescriptor

from .utils.adapter_base import AdapterBase
from .utils.column_type import ColumnType
from .utils.relation_type import RelationType


# --- start debug tool ---------------------------------------------------------
from django.db import connection
from functools import wraps


class _DebugQueries(object):
    def start(self):
        self.start_query_count = len(connection.queries)

    def end(self):
        print("+" * 100)
        for i, q in enumerate(connection.queries[self.start_query_count :]):
            print("-" * 80)
            print(i, q["sql"])
        print("*" * 100, flush=True)

    def __enter__(self):
        self.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end()


dq = _DebugQueries()


def _dq_decorator(func):
    @wraps(func)
    def dq_decorated(self, *args, **kwargs):
        global dq
        with dq:
            return func(self, *args, **kwargs)

    return dq_decorated


# --- end debug tool -----------------------------------------------------------


class DjangoAdapter(AdapterBase):
    """An adapter for working with Django ORM.

    The `adapter_settings` is a dictionary with the following keys:

        - *using* is a database name (a key from the `DATABASES` dictionary
          in the "settings.py" file of a Django project); if it isn't
          provided, the "default" database is used.

    .. note::
        If you are going to use transactions provided by this library, set
        `autocommit` for the database to `False`:

        .. code-block:: Python

            from django.db import transaction
            from save_to_db import Persister
            from save_to_db.adapters import DjangoAdapter

            persister = Persister(DjangoAdapter(adapter_settings={}))

            # Next function takes a `using` argument which should be the
            # name of a database. If it isn't provided, Django uses the
            # "default" database.
            transaction.set_autocommit(False, using='database_name')

            try:
                persister.persist(save_to_db_item)
                persister.commit()
            except Exception:
                persister.rollback()
                raise

        Alternatively, you can use Django transactions directly:

        .. code-block:: Python

            from django.db import transaction
            from save_to_db import Persister
            from save_to_db.adapters import DjangoAdapter

            persister = Persister(DjangoAdapter(adapter_settings={}))

            with transaction.atomic():
                persister.persist(save_to_db_item)
    """

    COMPOSITE_KEYS_SUPPORTED = False
    REVERSE_MODEL_AUTOUPDATE_SUPPORTED = False
    SAVE_MODEL_BEFORE_COMMIT = True

    # Field classes that represent relations (forward fields and reverse
    # relation objects); shared by `iter_fields` and `iter_relations`.
    _RELATION_FIELD_CLASSES = (
        ForeignKey,
        ManyToOneRel,
        OneToOneField,
        OneToOneRel,
        ManyToManyField,
        ManyToManyRel,
    )

    def __init__(self, adapter_settings):
        super().__init__(adapter_settings)
        self.using = adapter_settings.get("using", "default")

    # --- general methods ------------------------------------------------------

    @classmethod
    def is_usable(cls, model_cls):
        """Return True if `model_cls` is a Django model class."""
        return issubclass(model_cls, models.Model)

    def commit(self):
        """Commit the current transaction on the configured database."""
        transaction.commit(using=self.using)

    def rollback(self):
        """Roll back the current transaction on the configured database."""
        transaction.rollback(using=self.using)

    @classmethod
    def iter_fields(cls, model_cls):
        """Yield `(field_name, ColumnType)` for every non-relation field.

        Fields whose class is not listed below are reported as
        `ColumnType.OTHER`.
        """
        # The first matching class wins, so more specific classes must come
        # before their bases (Django's DateTimeField derives from DateField).
        type_to_const = (
            (models.BinaryField, ColumnType.BINARY),
            (models.TextField, ColumnType.TEXT),
            (models.CharField, ColumnType.STRING),
            ((models.AutoField, models.IntegerField), ColumnType.INTEGER),
            (
                (models.BooleanField, models.NullBooleanField),
                ColumnType.BOOLEAN,
            ),
            (models.DateTimeField, ColumnType.DATETIME),
            (models.FloatField, ColumnType.FLOAT),
            (models.DecimalField, ColumnType.DECIMAL),
            (models.DateField, ColumnType.DATE),
            (models.TimeField, ColumnType.TIME),
        )

        for field in model_cls._meta.get_fields():
            if isinstance(field, cls._RELATION_FIELD_CLASSES):
                continue

            yield_type = ColumnType.OTHER
            for field_class, column_type in type_to_const:
                if isinstance(field, field_class):
                    yield_type = column_type
                    break

            yield field.name, yield_type

    @classmethod
    def iter_relations(cls, model_cls):
        """Yield relation descriptions of `model_cls`.

        Each item is `(field_name, related_model, RelationType,
        remote_field_name)`. `remote_field_name` is None when the remote side
        reports itself as hidden via `is_hidden()`.
        """
        # One-to-one classes derive from the FK / reverse-FK classes, so they
        # are checked first.
        type_to_const = (
            ((OneToOneField, OneToOneRel), RelationType.ONE_TO_ONE),
            (ForeignKey, RelationType.MANY_TO_ONE),
            (ManyToOneRel, RelationType.ONE_TO_MANY),
            ((ManyToManyField, ManyToManyRel), RelationType.MANY_TO_MANY),
        )

        for field in model_cls._meta.get_fields():
            if not isinstance(field, cls._RELATION_FIELD_CLASSES):
                continue

            for field_class, relation_type in type_to_const:
                if isinstance(field, field_class):
                    yield_type = relation_type
                    break

            # Reverse relation objects are reached through their accessor.
            field_name = field.name
            if hasattr(field, "get_accessor_name"):
                field_name = field.get_accessor_name()

            remote = field.remote_field
            remote_field_name = None
            if not hasattr(remote, "is_hidden") or not remote.is_hidden():
                remote_field_name = remote.name
                if hasattr(remote, "get_accessor_name"):
                    remote_field_name = remote.get_accessor_name()

            yield field_name, field.related_model, yield_type, remote_field_name

    @classmethod
    def iter_required_fields(cls, model_cls):
        """Yield names of non-nullable fields that have no default value.

        Auto fields and many-to-many fields are never reported.
        """
        for field in model_cls._meta.get_fields():
            if isinstance(field, (models.AutoField, ManyToManyField)):
                continue
            # Entries returned by `get_fields()` that are not regular fields
            # may lack `null` / `default`; treat a missing `null` as nullable
            # instead of raising AttributeError.
            if getattr(field, "null", True):
                continue
            default = getattr(field, "default", models.NOT_PROVIDED)
            if default is models.NOT_PROVIDED:
                yield field.name

    @classmethod
    def iter_unique_field_combinations(cls, model_cls):
        """Yield collections of field names whose values are unique together."""
        for unique_constraint_fields in model_cls._meta.unique_together:
            yield unique_constraint_fields

        for field in model_cls._meta.get_fields():
            if hasattr(field, "field"):  # reverse relation object
                if isinstance(field, (OneToOneField, OneToOneRel)):
                    yield {field.name}  # one-to-one is always unique
                    continue
                field = field.field
            # Guard against entries without a `unique` attribute.
            if getattr(field, "unique", False):
                yield {field.name}

    @classmethod
    def get_table_fullname(cls, model_cls):
        """Return the database table name of `model_cls`."""
        return model_cls._meta.db_table

    def get_model_cls_by_table_fullname(self, name):
        """Return the registered model whose table is `name`, or None."""
        for model in self.iter_all_models():
            if self.get_table_fullname(model) == name:
                return model
        return None

    def iter_all_models(self):
        """Yield every model class registered in Django's app registry."""
        # `app_models` instead of `models` avoids shadowing django.db.models.
        for app_models in apps.all_models.values():
            yield from app_models.values()

    # --- methods for working with items ---------------------------------------

    @staticmethod
    def __get_group_filters(item, fkeys, group):
        """Build a Q matching `item` on every field of one getter `group`.

        Returns None when the group cannot be used: a field is absent from
        the item, or a related model could not be fetched or created earlier.
        """
        group_filters = Q()
        for field_name in group:
            if field_name not in item:
                return None

            if field_name in item.fields:
                group_filters &= Q(**{field_name: item[field_name]})
            elif field_name in item.relations:
                related_models = fkeys.get(field_name)
                if not related_models:
                    # failed to get or create the related model before
                    return None

                relation = item.relations[field_name]
                if not relation["relation_type"].is_x_to_many():
                    group_filters &= Q(**{field_name: related_models[0]})
                else:
                    contains_any = Q()
                    for related_model in related_models:
                        contains_any |= Q(**{field_name: related_model})
                    group_filters &= contains_any
        return group_filters

    def get(self, items_and_fkeys):
        """Fetch models matching any of the given items in a single query.

        :param items_and_fkeys: sequence of `(item, fkeys)` pairs, where
            `fkeys` maps relation names to lists of related models. The
            getters of the first item are used for every item.
        :returns: an empty list when nothing can be filtered on, otherwise
            the resulting queryset.
        """
        if not items_and_fkeys:
            return []

        first_item = items_and_fkeys[0][0]
        getters = first_item.getters

        all_items_filters = Q()
        for item, fkeys in items_and_fkeys:
            one_item_filters = Q()
            for group in getters:
                group_filters = self.__get_group_filters(item, fkeys, group)
                if group_filters is not None and len(group_filters):
                    one_item_filters |= group_filters
            if one_item_filters.children:
                all_items_filters |= one_item_filters

        if not len(all_items_filters):
            return []

        return (
            first_item.model_cls.objects.db_manager(self.using)
            .filter(all_items_filters)
            .all()
        )

    def delete(self, model):
        """Delete `model` from the configured database."""
        model.delete(using=self.using)

    def create_blank_model(self, model_cls):
        """Return a new, unsaved instance of `model_cls`."""
        return model_cls()

    def add_related_models(self, model, fkey, related_models):
        """Add `related_models` to the x-to-many relation `fkey` of `model`."""
        getattr(model, fkey).add(*related_models)

    def clear_related_models(self, model, fkey):
        """Remove all models from the relation `fkey` of `model`.

        :raises Exception: if the related manager has no `clear()` method
            (Django omits it for reverse relations of non-nullable foreign
            keys).
        """
        related = getattr(model, fkey)
        if not hasattr(related, "clear"):
            raise Exception("Cannot clear required field.")
        related.clear()

    def related_x_to_many_exists(self, model, fkey):
        """Return True if the relation `fkey` of `model` has any members."""
        return getattr(model, fkey).exists()

    def related_x_to_many_contains(self, model, fkey, child_models):
        """Return those `child_models` that belong to the relation `fkey`."""
        if not child_models:
            return []

        # no composite PK in Django
        pk_name = self.get_primary_key_names(model)[0]
        pk_list = [getattr(child_model, pk_name) for child_model in child_models]
        contained_models = getattr(model, fkey).filter(
            **{"{}__in".format(pk_name): pk_list}
        )

        result = []
        for contained_model in contained_models:
            for child_model in child_models:
                if child_model == contained_model:
                    result.append(child_model)
                    break
        return result

    @classmethod
    def get_primary_key_names(cls, model_cls):
        """Return a one-element tuple with the primary key field name."""
        return (model_cls._meta.pk.name,)

    def save_model(self, model):
        """Save `model` to the configured database."""
        model.save(using=self.using)

    @classmethod
    def __get_select_and_keep_filter(cls, selectors, keepers):
        """Combine selectors and keepers into a single Q.

        A row matches when it equals any selector entry (all fields of that
        entry) and differs from every keeper entry (in at least one field).
        Returns None when both the selector filter and the keeper filter are
        empty.
        """
        total_select_filter = Q()
        for selectors_entry in selectors:
            select_filter_set = Q()
            for field_name, value in selectors_entry.items():
                select_filter_set &= Q(**{field_name: value})
            total_select_filter |= select_filter_set

        total_keep_filter = Q()
        for keepers_entry in keepers:
            # NOT(all fields equal) == any field differs (De Morgan).
            keep_filter_set = Q()
            for field_name, value in keepers_entry.items():
                keep_filter_set |= ~Q(**{field_name: value})
            total_keep_filter &= keep_filter_set

        if total_select_filter or total_keep_filter:
            return Q(total_select_filter, total_keep_filter)
        return None

    def execute_delete(self, model_cls, selectors, keepers):
        """Delete rows of `model_cls` matched by selectors but not keepers."""
        filters = self.__get_select_and_keep_filter(selectors, keepers)
        if filters is None:
            return
        model_cls.objects.db_manager(self.using).filter(filters).delete()

    def execute_unref(self, parent, fkey, selectors, keepers):
        """Detach matching children from the relation `fkey` of `parent`."""
        filters = self.__get_select_and_keep_filter(selectors, keepers)
        if filters is None:
            return
        child_relation = getattr(parent, fkey)
        children = child_relation.filter(filters).all()
        child_relation.remove(*children)

    def get_related_x_to_many(self, model, fkey):
        """Return all models in the relation `fkey` of `model`."""
        return getattr(model, fkey).all()

    # --- methods for tests ----------------------------------------------------

    def get_all_models(self, model_cls, sort_key=None):
        """Return all rows of `model_cls` as a list, optionally sorted."""
        # `all_models` instead of `models` avoids shadowing django.db.models.
        all_models = list(model_cls.objects.db_manager(self.using).all())
        if sort_key:
            all_models.sort(key=sort_key)
        return all_models