datatracker/south/modelsparser.py
2010-07-21 12:48:05 +00:00

429 lines
14 KiB
Python

"""
Parsing module for models.py files. Extracts information in a more reliable
way than inspect + regexes.
Now only used as a fallback when introspection and the South custom hook both fail.
"""
import re
import inspect
import parser
import symbol
import token
import keyword
import datetime
from django.db import models
from django.contrib.contenttypes import generic
from django.utils.datastructures import SortedDict
from django.core.exceptions import ImproperlyConfigured
def name_that_thing(thing):
    """
    Turn a symbol/token integer into a readable, module-qualified name.

    The ``symbol`` module is searched first, then ``token``; the first
    attribute (in ``dir()`` order) whose value equals ``thing`` wins.
    Falls back to ``str(thing)`` when neither module has a match.
    """
    lookups = (
        ("symbol.%s", symbol),
        ("token.%s", token),
    )
    for template, module in lookups:
        for attr in dir(module):
            if getattr(module, attr) == thing:
                return template % attr
    return str(thing)
def thing_that_name(name):
    """
    Turn the name of a symbol/token into its integer value.

    ``symbol`` takes precedence over ``token``. Raises ValueError when the
    name is not listed by ``dir()`` for either module.
    """
    for module in (symbol, token):
        if name in dir(module):
            return getattr(module, name)
    raise ValueError("Cannot convert '%s'" % name)
def prettyprint(tree, indent=0, omit_singles=False):
    """
    Prettyprints the tree, with symbol/token names. For debugging.

    ``tree`` is a nested tuple; ints inside it are rendered through
    name_that_thing() and anything else through repr().
    ``indent`` is the current nesting depth (one space per level).
    When ``omit_singles`` is true, every two-element tuple (a node wrapping
    exactly one child) is replaced by the rendering of that child, at every
    depth of the tree.

    Returns the rendered string.
    """
    if isinstance(tree, tuple):
        if omit_singles and len(tree) == 2:
            return prettyprint(tree[1], indent, omit_singles)
        # Forward omit_singles so the collapsing applies below the root too;
        # previously nested nodes were always rendered in full.
        children = "".join(
            [prettyprint(child, indent + 1, omit_singles) for child in tree]
        )
        return " (\n%s\n" % children + (" " * indent) + ")"
    elif isinstance(tree, int):
        return (" " * indent) + name_that_thing(tree)
    else:
        return " " + repr(tree)
def isclass(obj):
    """Tell whether ``obj`` is a class, i.e. its type derives from ``type``."""
    metaclass = type(obj)
    return issubclass(metaclass, type)
def aliased_models(module):
    """
    Build an alias -> real-name mapping for models in a models module.

    Every module-level name bound to a models.Model subclass (other than
    models.Model itself) is compared against the class's
    ``_meta.object_name``; names that differ (e.g. ``import Foo as Bar``)
    are recorded. Bug #134.
    """
    aliases = {}
    for alias, candidate in module.__dict__.items():
        if not isclass(candidate):
            continue
        if not issubclass(candidate, models.Model) or candidate is models.Model:
            continue
        # Only names that differ from the model's own name are aliases
        real_name = candidate._meta.object_name
        if alias != real_name:
            aliases[alias] = real_name
    return aliases
class STTree(object):
    """
    A syntax tree wrapper class.

    Wraps a nested-tuple syntax tree (the form returned by
    ``parser.suite(...).totuple()``), where each tuple's first element is
    its symbol/token number and the remaining elements are its children.
    """

    def __init__(self, tree):
        self.tree = tree

    def __eq__(self, other):
        # Comparing against something that is not an STTree must not blow
        # up with AttributeError; let Python fall back to its default.
        if not isinstance(other, STTree):
            return NotImplemented
        return other.tree == self.tree

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(self.tree)

    @property
    def root(self):
        "The symbol/token number of the top node."
        return self.tree[0]

    @property
    def value(self):
        "The raw wrapped tuple."
        return self.tree

    def walk(self, recursive=True):
        """
        Yields (symbol, subtree) for the entire subtree.
        Comes out with node 1, node 1's children, node 2, etc.

        The wrapped root node itself is never yielded. With
        recursive=False only the root's direct children are yielded.
        """
        stack = [self.tree]
        done_outer = False
        while stack:
            atree = stack.pop()
            if isinstance(atree, tuple):
                if done_outer:
                    yield atree[0], STTree(atree)
                # The root's children are always expanded; deeper levels
                # only when walking recursively.
                if recursive or not done_outer:
                    # Pushed in reverse so they pop in source order
                    for bit in reversed(atree[1:]):
                        stack.append(bit)
                    done_outer = True

    def flatten(self):
        """
        Returns a list of the tokens in the tree only, in order.

        Tokens present in token_map are represented by their token number;
        NAME, STRING and NUMBER tokens by their raw (type, text, ...) tuple.
        Everything else is dropped.
        """
        bits = []
        for sym, subtree in self.walk():
            if sym in token_map:
                bits.append(sym)
            elif sym in (token.NAME, token.STRING, token.NUMBER):
                bits.append(subtree.value)
        return bits

    def reform(self):
        "Returns how the tree's input probably looked, as a string."
        return reform(self.flatten())

    def findAllType(self, ntype, recursive=True):
        "Yields all nodes with the given type in the tree."
        # 'node_type' rather than 'symbol', to avoid shadowing the module
        for node_type, subtree in self.walk(recursive=recursive):
            if node_type == ntype:
                yield subtree

    def find(self, selector):
        """
        Searches the syntax tree with a CSS-like selector syntax.
        You can use things like 'suite simple_stmt', 'suite, simple_stmt'
        or 'suite > simple_stmt'. Not guaranteed to return in order.

        A leading '^' stands for this tree itself. Returns a list of
        STTree nodes; raises ValueError (from thing_that_name) for a part
        that is not a known symbol/token name.
        """
        # Split up the overall parts
        patterns = [x.strip() for x in selector.split(",")]
        results = []
        for pattern in patterns:
            # Split up the parts; the capturing group keeps '>' separators
            # in the list (plain whitespace separators show up as None).
            parts = re.split(r'(?:[\s]|(>))+', pattern)
            # Take the first part, use it for results
            if parts[0] == "^":
                subresults = [self]
            else:
                subresults = list(self.findAllType(thing_that_name(parts[0])))
            recursive = True
            # For each remaining part, do something
            for part in parts[1:]:
                if not subresults:
                    break
                if part == ">":
                    # Next name must be a direct child
                    recursive = False
                elif not part:
                    pass
                else:
                    thing = thing_that_name(part)
                    matches = []
                    for tree in subresults:
                        matches.extend(tree.findAllType(thing, recursive))
                    subresults = matches
                    recursive = True
            results.extend(subresults)
        return results

    def __str__(self):
        return prettyprint(self.tree)

    __repr__ = __str__
def get_model_tree(model):
    """
    Locate the syntax tree of a model's class definition.

    Fetches the model's source via inspect, parses it, and returns the
    first ``compound_stmt`` node that is a class definition whose name
    matches ``model.__name__`` (case-insensitively). Returns None if the
    source cannot be read or no such class definition is found.
    """
    try:
        raw_source = inspect.getsource(model)
    except IOError:
        return None
    # Normalise line endings; the parser wants a trailing newline
    normalised = raw_source.replace("\r\n", "\n").replace("\r", "\n") + "\n"
    syntax_tree = STTree(parser.suite(normalised).totuple())
    for candidate in syntax_tree.find("compound_stmt"):
        statement = candidate.value[1]
        if statement[0] != symbol.classdef:
            continue
        # statement[2] is the NAME token holding the class name
        if statement[2][1].lower() == model.__name__.lower():
            return candidate
    return None
# Maps operator/delimiter token numbers to the source text they stand for.
# Used by STTree.flatten() to pick out tokens and by reform() to turn them
# back into source.
token_map = {
    token.DOT: ".",
    token.LPAR: "(",
    token.RPAR: ")",
    token.EQUAL: "=",
    token.EQEQUAL: "==",
    token.COMMA: ",",
    token.LSQB: "[",
    token.RSQB: "]",
    token.AMPER: "&",
    token.CIRCUMFLEX: "^",
    token.CIRCUMFLEXEQUAL: "^=",
    token.COLON: ":",
    token.DOUBLESLASH: "//",
    token.DOUBLESLASHEQUAL: "//=",
    token.DOUBLESTAR: "**",
    # Was keyed on DOUBLESLASHEQUAL a second time, which clobbered the
    # "//=" entry above and left "**=" unmapped.
    token.DOUBLESTAREQUAL: "**=",
    token.GREATER: ">",
    token.LESS: "<",
    token.GREATEREQUAL: ">=",
    token.LESSEQUAL: "<=",
    token.LBRACE: "{",
    token.RBRACE: "}",
    token.SEMI: ";",
    token.PLUS: "+",
    token.MINUS: "-",
    token.STAR: "*",
    token.SLASH: "/",
    token.VBAR: "|",
    token.PERCENT: "%",
    token.TILDE: "~",
    token.AT: "@",
    token.NOTEQUAL: "!=",
    token.LEFTSHIFT: "<<",
    token.RIGHTSHIFT: ">>",
    token.LEFTSHIFTEQUAL: "<<=",
    token.RIGHTSHIFTEQUAL: ">>=",
    token.PLUSEQUAL: "+=",
    token.MINEQUAL: "-=",
    token.STAREQUAL: "*=",
    token.SLASHEQUAL: "/=",
    token.VBAREQUAL: "|=",
    token.PERCENTEQUAL: "%=",
    token.AMPEREQUAL: "&=",
}
# The backquote (repr) token only exists in Python 2's token module.
if hasattr(token, "BACKQUOTE"):
    token_map[token.BACKQUOTE] = "`"
def reform(bits):
    """
    Rebuild the source string that the list of tokens 'bits' represents.

    Entries found in token_map are replaced by their text. NAME, STRING and
    NUMBER entries (tuples whose second element is the token text)
    contribute that text, with keywords padded by a space on each side.
    """
    literal_types = [token.NAME, token.STRING, token.NUMBER]
    pieces = []
    for bit in bits:
        if bit in token_map:
            pieces.append(token_map[bit])
            continue
        if bit[0] not in literal_types:
            continue
        text = bit[1]
        if keyword.iskeyword(text):
            # Keywords need whitespace around them to stay separate words
            pieces.append(" %s " % text)
        elif text not in symbol.sym_name:
            pieces.append(text)
    return "".join(pieces)
def parse_arguments(argstr):
    """
    Split a string of call arguments into positional and keyword parts.

    Returns ``(args, kwds)``: a list of positional arguments and a dict of
    keyword arguments. Every entry is python source text, except the dict
    keys, which are the keyword names.
    """
    syntax_tree = STTree(parser.suite(argstr).totuple())
    positional = []
    keywords = {}
    pending_key = None
    # BTW: A testlist is to the left or right of an =.
    groups = syntax_tree.find("testlist")
    final_group = len(groups) - 1
    for group_no, group in enumerate(groups):
        children = list(group.walk(recursive=False))
        final_child = len(children) - 1
        for child_no, (node_type, subtree) in enumerate(children):
            if node_type != symbol.test:
                continue
            text = subtree.reform()
            if pending_key:
                # This is the value for the keyword seen just before the =
                keywords[pending_key] = text
                pending_key = None
            elif child_no == final_child and group_no != final_group:
                # Last item in a group must be a keyword, unless it's last overall
                pending_key = text
            else:
                positional.append(text)
    return positional, keywords
def extract_field(tree):
    """
    Try to read a flattened statement tree as ``name = Class(args...)``.

    Returns ``(name, clsname, args, kwargs)`` on success, where args/kwargs
    come from parse_arguments(). Returns None when the statement is not an
    assignment, contains no call, or its arguments fail to parse.
    """
    bits = tree.flatten()
    # The second token has to be an equals sign
    looks_like_assignment = len(bits) >= 2 and bits[1] == token.EQUAL
    if not looks_like_assignment:
        return None
    name = bits[0][1]
    declaration = bits[2:]
    # Everything before the first LPAR is the class being called
    if token.LPAR not in declaration:
        return None
    lpar_at = declaration.index(token.LPAR)
    clsname = reform(declaration[:lpar_at])
    # The arguments sit between that LPAR and the last RPAR
    rpar_at = (len(declaration) - 1) - declaration[::-1].index(token.RPAR)
    arg_bits = declaration[lpar_at + 1:rpar_at]
    try:
        args, kwargs = parse_arguments(reform(arg_bits))
    except SyntaxError:
        return None
    return name, clsname, args, kwargs
def get_model_fields(model, m2m=False):
    """
    Given a model class, will return the dict of name: field_constructor
    mappings.

    ``m2m`` additionally includes the model's local many-to-many fields.

    Returns a SortedDict keyed by field name. Each value is either a
    (class name, positional args, keyword args) definition, or None
    when no definition could be worked out for that field. Returns None
    altogether when the model's syntax tree cannot be obtained.
    """
    tree = get_model_tree(model)
    if tree is None:
        return None
    possible_field_defs = tree.find("^ > classdef > suite > stmt > simple_stmt > small_stmt > expr_stmt")
    field_defs = {}
    # Get aliases, ready for alias fixing (#134)
    try:
        aliases = aliased_models(models.get_app(model._meta.app_label))
    except ImproperlyConfigured:
        aliases = {}
    # Go through all the found defns, and try to parse them
    for pfd in possible_field_defs:
        field = extract_field(pfd)
        if field:
            field_defs[field[0]] = field[1:]
    inherited_fields = {}
    # Go through all bases (that are themselves models, but not Model)
    for base in model.__bases__:
        if base != models.Model and issubclass(base, models.Model):
            # The recursive call returns None when the base's source can't
            # be found; dict.update(None) would raise TypeError.
            base_fields = get_model_fields(base, m2m)
            if base_fields:
                inherited_fields.update(base_fields)
    # Now, go through all the fields and try to get their definition
    source = model._meta.local_fields[:]
    if m2m:
        source += model._meta.local_many_to_many
    fields = SortedDict()
    for field in source:
        # Get its name
        fieldname = field.name
        if isinstance(field, (models.related.RelatedObject, generic.GenericRel)):
            continue
        # Now, try to get the defn
        if fieldname in field_defs:
            fields[fieldname] = field_defs[fieldname]
        # Try the South definition workaround?
        elif hasattr(field, 'south_field_triple'):
            fields[fieldname] = field.south_field_triple()
        elif hasattr(field, 'south_field_definition'):
            print("Your custom field %s provides the outdated south_field_definition method.\nPlease consider implementing south_field_triple too; it's more reliably evaluated." % field)
            fields[fieldname] = field.south_field_definition()
        # Try a parent?
        elif fieldname in inherited_fields:
            fields[fieldname] = inherited_fields[fieldname]
        # Is it a _ptr?
        elif fieldname.endswith("_ptr"):
            fields[fieldname] = ("models.OneToOneField", ["orm['%s.%s']" % (field.rel.to._meta.app_label, field.rel.to._meta.object_name)], {})
        # Try a default for 'id'.
        elif fieldname == "id":
            fields[fieldname] = ("models.AutoField", [], {"primary_key": "True"})
        else:
            fields[fieldname] = None
    # Now, try seeing if we can resolve the values of defaults, and fix aliases.
    for field, defn in fields.items():
        if not isinstance(defn, (list, tuple)):
            continue # We don't have a defn for this one, or it's a string
        # Fix aliases if we can (#134)
        for i, arg in enumerate(defn[1]):
            if arg in aliases:
                defn[1][i] = aliases[arg]
        # Fix defaults if we can
        if 'default' not in defn[2]:
            continue
        val = defn[2]['default']
        try:
            # Evaluate it in a close-to-real fake model context.
            # NOTE: this eval()s text taken from the project's own models
            # source; it must never be fed untrusted input.
            real_val = eval(val, __import__(model.__module__, {}, {}, ['']).__dict__, model.__dict__)
        # If we can't resolve it, stick it in verbatim. Deliberately
        # best-effort, but let KeyboardInterrupt/SystemExit through.
        except Exception:
            continue # TODO: Raise nice error here?
        # Hm, OK, we got a value. Callables are not frozen (see #132, #135)
        if callable(real_val):
            # HACK
            # However, if it's datetime.now, etc., that's special
            for datetime_key in datetime.datetime.__dict__:
                # No, you can't use __dict__.values. It's different.
                dtm = getattr(datetime.datetime, datetime_key)
                if real_val == dtm:
                    if not val.startswith("datetime.datetime"):
                        defn[2]['default'] = "datetime." + val
                    break
        else:
            defn[2]['default'] = repr(real_val)
    return fields