diff --git a/ietf/utils/management/commands/makefixture.py b/ietf/utils/management/commands/makefixture.py index e715ea2be..16e98d8e8 100644 --- a/ietf/utils/management/commands/makefixture.py +++ b/ietf/utils/management/commands/makefixture.py @@ -1,10 +1,42 @@ -# From http://djangosnippets.org/snippets/918/ +# From https://github.com/ericholscher/django-test-utils/blob/master/test_utils/management/commands/makefixture.py +""" +"Make fixture" command. + +Highly useful for making test fixtures. Use it to pick only few items +from your data to serialize, restricted by primary keys. By default +command also serializes foreign keys and m2m relations. You can turn +off related items serialization with --skip-related option. + +How to use: +python manage.py makefixture + +will display what models are installed + +python manage.py makefixture User[:3] +or +python manage.py makefixture auth.User[:3] +or +python manage.py makefixture django.contrib.auth.User[:3] + +will serialize users with ids 1 and 2, with assigned groups, permissions +and content types. + +python manage.py makefixture YourModel[3] YourModel[6:10] + +will serialize YourModel with key 3 and keys 6 to 9 inclusively. + +Of course, you can serialize whole tables, and also different tables at +once, and use options of dumpdata: + +python manage.py makefixture --format=xml --indent=4 YourModel[3] AnotherModel auth.User[:5] auth.Group +""" +# From http://www.djangosnippets.org/snippets/918/ #save into anyapp/management/commands/makefixture.py #or back into django/core/management/commands/makefixture.py #v0.1 -- current version #known issues: -#no support for generic relations +#no support for generic relations #no support for one-to-one relations from optparse import make_option from django.core import serializers @@ -15,8 +47,6 @@ from django.db.models.fields.related import ForeignKey from django.db.models.fields.related import ManyToManyField from django.db.models.loading import get_models -import debug - DEBUG = False def model_name(m): @@ -29,20 +59,36 @@ class Command(LabelCommand): option_list = BaseCommand.option_list + ( make_option('--skip-related', default=True, action='store_false', dest='propagate', help='Specifies if we shall not add related objects.'), + make_option('--reverse', default=[], action='append', dest='reverse', + help="Reverse relations to follow (e.g. 'Job.task_set')."), make_option('--format', default='json', dest='format', help='Specifies the output serialization format for fixtures.'), make_option('--indent', default=None, dest='indent', type='int', help='Specifies the indent level to use when pretty-printing output'), - make_option('--include-reverse', default=False, action='store_true', dest='reverse', - help='Add reverse related objects too'), ) - + def handle_reverse(self, **options): + follow_reverse = options.get('reverse', []) + to_reverse = {} + for arg in follow_reverse: + try: + model_name, related_set_name = arg.rsplit(".", 1) + except: + raise CommandError("Bad fieldname on '--reverse %s'" % arg) + model = self.get_model_from_name(model_name) + try: + getattr(model, related_set_name) + except AttributeError: + raise CommandError("Field '%s' does not exist on model '%s'" % ( + related_set_name, model_name)) + to_reverse.setdefault(model, []).append(related_set_name) + return to_reverse + def handle_models(self, models, **options): format = options.get('format','json') indent = options.get('indent',None) show_traceback = options.get('traceback', False) propagate = options.get('propagate', True) - opt_reverse = options.get('reverse', False) + follow_reverse = self.handle_reverse(**options) # Check that the serialization format exists; this is a shortcut to # avoid collating all the objects and _then_ failing. @@ -56,7 +102,7 @@ class Command(LabelCommand): objects = [] for model, slice in models: - if isinstance(slice, basestring): + if isinstance(slice, basestring) and slice: objects.extend(model._default_manager.filter(pk__exact=slice)) elif not slice or type(slice) is list: items = model._default_manager.all() @@ -68,32 +114,16 @@ class Command(LabelCommand): objects.extend(items) else: raise CommandError("Wrong slice: %s" % slice) - + all = objects - collected = set([(x.__class__, x.pk) for x in all]) - - if opt_reverse: - related = [] - for x in objects: - attribs = [] - for name in dir(x): - try: - attribs.append(getattr(x, name)) - except AttributeError: - pass - for o in attribs: - if "django.db.models.fields.related.RelatedManager object" in repr(o): - for new in o.all(): - collected.add((new.__class__, new.pk)) - related.append(new) - all.extend(related) - if propagate: + collected = set([(x.__class__, x.pk) for x in all]) while objects: related = [] for x in objects: if DEBUG: print "Adding %s[%s]" % (model_name(x), x.pk) + # follow forward relation fields for f in x.__class__._meta.fields + x.__class__._meta.many_to_many: if isinstance(f, ForeignKey): new = getattr(x, f.name) # instantiate object @@ -105,9 +135,16 @@ class Command(LabelCommand): if new and not (new.__class__, new.pk) in collected: collected.add((new.__class__, new.pk)) related.append(new) + # follow reverse relations as requested + for reverse_field in follow_reverse.get(x.__class__, []): + mgr = getattr(x, reverse_field) + for new in mgr.all(): + if new and not (new.__class__, new.pk) in collected: + collected.add((new.__class__, new.pk)) + related.append(new) objects = related - all.extend(related) - + all.extend(objects) + try: return serializers.serialize(format, all, indent=indent) except Exception, e: @@ -118,6 +155,24 @@ class Command(LabelCommand): def get_models(self): return [(m, model_name(m)) for m in get_models()] + def get_model_from_name(self, search): + """Given a name of a model, return the model object associated with it + +The name can be either fully specified or uniquely matching the +end of the model name. e.g. +django.contrib.auth.User +or +auth.User +raises CommandError if model can't be found or uniquely determined +""" + models = [model for model, name in self.get_models() + if name.endswith('.'+name) or name == search] + if not models: + raise CommandError("Unknown model: %s" % search) + if len(models)>1: + raise CommandError("Ambiguous model name: %s" % search) + return models[0] + def handle_label(self, labels, **options): parsed = [] for label in labels: @@ -125,17 +180,12 @@ class Command(LabelCommand): if '[' in label: search, pks = label.split('[', 1) slice = '' - if ':' in pks: + if ':' in pks: slice = pks.rstrip(']').split(':', 1) - elif pks: + elif pks: slice = pks.rstrip(']') - models = [model for model, name in self.get_models() - if name.endswith('.'+search) or name == search] - if not models: - raise CommandError("Wrong model: %s" % search) - if len(models)>1: - raise CommandError("Ambiguous model name: %s" % search) - parsed.append((models[0], slice)) + model = self.get_model_from_name(search) + parsed.append((model, slice)) return self.handle_models(parsed, **options) def list_models(self): @@ -149,5 +199,5 @@ class Command(LabelCommand): output = [] label_output = self.handle_label(labels, **options) if label_output: - output.append(label_output) + output.append(label_output) return '\n'.join(output)