import ast
import collections
import unittest
import warnings
from contextlib import contextmanager
from cached_property import cached_property
__all__ = ['find_globals', 'FindFreeVariablesClassWarning']
[docs]class FindFreeVariablesClassWarning(UserWarning):
pass
class FindFreeVariables(ast.NodeVisitor):
''' use the `find_globals` utility function '''
@cached_property
def ffv_scope(self):
return collections.ChainMap()
@cached_property
def ffv_globals(self):
return set()
# possible scopes: local, global, force_global
@property
def ffv_new_scope(self):
scope = self.ffv_scope
@contextmanager
def context():
scope.maps.insert(0, dict())
yield
fglobals = self.ffv_globals
for name, var_scope in scope.maps[0].items():
if var_scope != 'local':
fglobals.add(name)
del scope.maps[0]
return context
def ffv_binding(self, name, state='local'):
d = self.ffv_scope
cur = d.get(name, 'global')
if cur != 'force_global':
if not (cur == 'local' and state == 'global'):
d[name] = state
def visit_ListComp(self, node):
with self.ffv_new_scope():
super().generic_visit(node)
def visit_SetComp(self, node):
self.visit_ListComp(node)
def visit_DictComp(self, node):
self.visit_ListComp(node)
def visit_GeneratorExp(self, node):
self.visit_ListComp(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store): # establish new binding
self.ffv_binding(node.id)
elif isinstance(node.ctx, ast.Load):
self.ffv_binding(node.id, state='global')
super().generic_visit(node)
def visit_FunctionDef(self, node):
with self.ffv_new_scope():
args = node.args
for a in args.args:
self.ffv_binding(a.arg)
for a in args.kwonlyargs:
self.ffv_binding(a.arg)
if args.vararg:
self.ffv_binding(args.vararg.arg)
if args.kwarg:
self.ffv_binding(args.kwarg.arg)
super().generic_visit(node)
name = getattr(node, 'name', None)
if name: self.ffv_binding(node.name)
def visit_AsyncFunctionDef(self, node):
self.visit_FunctionDef(node)
def visit_Lambda(self, node):
self.visit_FunctionDef(node)
def visit_ClassDef(self, node):
warnings.warn(FindFreeVariablesClassWarning(
"Global variable detection is not fully implemented for Python "
"classes. This may lead to false dependencies being established "
"in the caching system."))
super().generic_visit(node)
def visit_ExceptHandler(self, node):
if node.name:
self.ffv_binding(node.name)
super().generic_visit(node)
def visit_Import(self, node):
for alias in node.names:
self.ffv_binding(alias.asname)
super().generic_visit(node)
def visit_ImportFrom(self, node):
self.visit_Import(node)
def visit_Global(self, node):
for name in node.names:
self.ffv_binding(name, 'force_global')
super().generic_visit(node)
def visit_With(self, node):
super().generic_visit(node)
[docs]def find_globals(ast_tree):
'''
Returns a set of all global variable names referenced in the code.
Warning: does not yet work properly on classes.'''
ffv = FindFreeVariables()
ffv.visit(ast_tree)
return ffv.ffv_globals
class UnitTest(unittest.TestCase):
def test_me(self):
codes = [('''\
def f(a, b):
l1 = a + b
g1(g2, l1)
''', 'g1 g2'),
('''\
def f(a):
l1, g1.abc, l2 = a + g2
if a:
l3 = g3(l2, name=l1, *l3, **g4)
g5.A2[l3] = g6.A1[l2]
''', 'g1 g2 g3 g4 g5 g6'),
('''\
def f(a):
for i in a:
g1(a, i)
l1 = g2[g3(i)]
with g4(l1) as l2:
g1(l1, l2)
def h(b, arg=g5):
g6(l1, l2, i)
k = lambda b, arg=g7: g8(l1, b)
''', 'g1 g2 g3 g4 g5 g6 g7 g8'),
('''\
def f(a):
try:
l1 = g1()
except g4 as l2:
l1 = g2(l2)
finally:
g3()
''', 'g1 g2 g3 g4'),
('''\
def f(a):
import abc as l1
from abc import xyz as l2
g1(l1, l2)
''', 'g1'),
('''\
def f(a):
global g1
g1, l1 = a
''', 'g1'),
('''\
def f(a):
with g1(a, g2) as l1:
g2(l1)
g3(l1, a)
''', 'g1 g2 g3'),
('''\
def f(a):
g1(l1+g2 for l1 in a+g3)
g4([l2+g5 for l2 in a+g6])
''', 'g1 g2 g3 g4 g5 g6'),
('''\
def f(a):
def l1():
return a
return l1
''', ''),
('''\
def f1(g1):
def f2():
global g1
def f3():
g1()
''', 'g1'),
]
for code, gs in codes:
c = ast.parse(code)
self.assertEqual(set(gs.split()), find_globals(c))
def test_raises_warning(self):
c = ast.parse('''
def f(a):
class l1(g1, metaclass=g2):
l2 = a
g3 = a # here it's local to the class
def method(self):
g3() # here this is actually a global!
''')
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
find_globals(c)
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(
w[0].category, FindFreeVariablesClassWarning))