# Copyright (c) <2015-2016>, Eric Dill
#
# All rights reserved. Redistribution and use in source and binary forms, with
# or without modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
from __future__ import print_function, division, absolute_import
import ast
import logging
import os
import sys
from collections import defaultdict
from typing import Union
from .stdliblist import builtin_modules
from .utils import (
AST_QUESTIONABLE,
namespace_packages,
SKETCHY_TYPES_TABLE,
)
logger = logging.getLogger('depfinder')
PACKAGE_NAME = None
STRICT_CHECKING = False
def get_top_level_import_name(name, custom_namespaces=None):
num_dot = name.count(".")
custom_namespaces = custom_namespaces or []
if name in namespace_packages or name in custom_namespaces or name in builtin_modules:
return name
elif any(
((num_dot - nsp.count(".")) == 1) and name.startswith(nsp + ".")
for nsp in custom_namespaces
):
# this branch happens when name is foo.bar.baz and the namespace is
# foo.bar
return name
else:
if '.' not in name:
return name
else:
return get_top_level_import_name(
name.rsplit('.', 1)[0],
custom_namespaces=custom_namespaces
)
class ImportFinder(ast.NodeVisitor):
"""Find all imports in an Abstract Syntax Tree (AST).
Attributes
----------
required_modules : list
The list of imports that were found outside of try/except blocks,
function definitions and class definitions
sketchy_modules : list
The list of imports that were found inside of try/except blocks,
function definitions and class definitions
imports : list
The list of all ast.Import nodes in the AST
import_froms : list
The list of all ast.ImportFrom nodes in the AST
"""
def __init__(self, filename='', custom_namespaces=None):
self.filename = filename
self.required_modules = set()
self.sketchy_modules = set()
self.builtin_modules = set()
self.relative_modules = set()
self.imports = []
self.import_froms = []
self.total_imports = defaultdict(dict)
self.sketchy_nodes = {}
self.custom_namespaces = custom_namespaces or []
super(ImportFinder, self).__init__()
def visit(self, node):
"""Recursively visit all ast nodes.
Look for Import and ImportFrom nodes. Classify them as being imports
that are built in, relative, required or questionable. Questionable
imports are those that occur within the context of a try/except block, a
function definition or a class definition.
Parameters
----------
node : ast.AST
The node to start the recursion
"""
# add the node to the try/except block to signify that
# something potentially odd is going on in this import
if isinstance(node, AST_QUESTIONABLE):
self.sketchy_nodes[node] = node
super(ImportFinder, self).visit(node)
# after the node has been recursed in to, remove the try node
self.sketchy_nodes.pop(node, None)
def visit_Import(self, node: ast.Import):
"""Executes when an ast.Import node is encountered
an ast.Import node is something like 'import bar'
If ImportCatcher is inside of a try block then the import that has just
been encountered will be added to the `sketchy_modules` instance
attribute. Otherwise the module will be added to the `required_modules`
instance attribute
"""
self.imports.append(node)
self._add_to_total_imports(node)
mods = set([
get_top_level_import_name(name.name, custom_namespaces=self.custom_namespaces)
for name in node.names
])
for mod in mods:
self._add_import_node(mod)
def visit_ImportFrom(self, node: ast.ImportFrom):
"""Executes when an ast.ImportFrom node is encountered
an ast.ImportFrom node is something like 'from foo import bar'
If ImportCatcher is inside of a try block then the import that has just
been encountered will be added to the `sketchy_modules` instance
attribute. Otherwise the module will be added to the `required_modules`
instance attribute
"""
self.import_froms.append(node)
if node.module is None:
# this is a relative import like 'from . import bar'
# so do nothing
return
if node.level > 0:
# this is a relative import like 'from .foo import bar'
node_name = get_top_level_import_name(
node.module, custom_namespaces=self.custom_namespaces
)
self.relative_modules.add(node_name)
return
# this is a non-relative import like 'from foo import bar'
self._add_to_total_imports(node)
node_name = get_top_level_import_name(
node.module,
custom_namespaces=self.custom_namespaces,
)
self._add_import_node(node_name)
def _add_to_total_imports(self, node: Union[ast.Import, ast.ImportFrom]):
import_metadata = {}
try:
import_metadata.update({'exact_line': ast.unparse(node)})
except AttributeError:
pass
import_metadata.update({v: False for v in SKETCHY_TYPES_TABLE.values()})
import_metadata.update({SKETCHY_TYPES_TABLE[node.__class__]: True for node in self.sketchy_nodes})
names = set()
if isinstance(node, ast.Import):
_names = set(name.name for name in node.names)
import_metadata['import'] = _names
names.update(_names)
elif isinstance(node, ast.ImportFrom):
import_metadata['import_from'] = {node.module}
names.add(node.module)
else:
raise NotImplementedError(f"Expected ast.Import or ast.ImportFrom this is {type(node)}")
for name in names:
self.total_imports[name].update({(self.filename, node.lineno): import_metadata})
def _add_import_node(self, node_name):
# see if the module is a builtin
if node_name in builtin_modules:
self.builtin_modules.add(node_name)
return
# see if we are in a try block
if self.sketchy_nodes:
self.sketchy_modules.add(node_name)
return
# if none of the above cases are true, it is likely that this
# ImportFrom node occurs at the top level of the module
self.required_modules.add(node_name)
def describe(self):
"""Return the found imports
Returns
-------
dict :
'required': The modules that were encountered outside of a
try/except block
'questionable': The modules that were encountered inside of a
try/except block
'relative': The modules that were imported via relative import
syntax
'builtin' : The modules that are part of the standard library
"""
desc = {
'required': self.required_modules,
'relative': self.relative_modules,
'questionable': self.sketchy_modules,
'builtin': self.builtin_modules
}
desc = {k: v for k, v in desc.items() if v}
return desc
def __repr__(self):
return 'ImportCatcher: %s' % repr(self.describe())
[docs]
def get_imported_libs(code, filename='', custom_namespaces=None):
"""Given a code snippet, return a list of the imported libraries
Parameters
----------
code : str
The code to parse and look for imports
Returns
-------
ImportCatcher
The ImportCatcher is the object in `depfinder` that contains all the
information regarding which imports were found where. You will most
likely be interested in calling the describe() function on this return
value.
Examples
--------
>>> depfinder.get_imported_libs('from foo import bar')
{'required': {'foo'}, 'questionable': set()}
>>> with open('depfinder.py') as f:
code = f.read()
imports = depfinder.get_imported_libs(code)
print(imports.describe())
{'builtin': {'__future__', 'json', 'ast', 'os', 'sys', 'collections'},
'required': {'stdlib_list'}}
"""
# skip ipython notebook lines
code = '\n'.join([line for line in code.split('\n')
if not line.startswith('%')])
tree = ast.parse(code)
import_finder = ImportFinder(filename=filename, custom_namespaces=custom_namespaces)
import_finder.visit(tree)
return import_finder
def parse_file(python_file, custom_namespaces=None):
"""Parse a single python file
Parameters
----------
python_file : str
Path to the python file to parse for imports
Returns
-------
catchers : tuple
Yields tuples of (module_name, full_path_to_module, ImportCatcher)
"""
global PACKAGE_NAME
if PACKAGE_NAME is None:
PACKAGE_NAME = os.path.basename(python_file).split('.')[0]
logger.debug("Setting PACKAGE_NAME global variable to {}"
"".format(PACKAGE_NAME))
# Try except block added for adal package which has a BOM at the beginning,
# requiring a different encoding to load properly
try:
with open(python_file, 'r') as f:
code = f.read()
catcher = get_imported_libs(
code, filename=python_file, custom_namespaces=custom_namespaces
)
except SyntaxError:
with open(python_file, 'r', encoding='utf-8-sig') as f:
code = f.read()
catcher = get_imported_libs(
code, filename=python_file, custom_namespaces=custom_namespaces
)
catcher.total_imports = dict(catcher.total_imports)
mod_name = os.path.split(python_file)[:-3]
return mod_name, python_file, catcher
[docs]
def iterate_over_library(path_to_source_code, custom_namespaces=None):
"""Helper function to recurse into a library and find imports in .py files.
This allows the user to apply filters on the user-side to exclude imports
based on their file names.
`conda-skeletor <https://github.com/ericdill/conda-skeletor>`_
makes heavy use of this function
Parameters
----------
path_to_source_code : str
Yields
-------
catchers : tuple
Yields tuples of (module_name, full_path_to_module, ImportCatcher)
"""
global PACKAGE_NAME
global STRICT_CHECKING
if PACKAGE_NAME is None:
PACKAGE_NAME = os.path.basename(path_to_source_code).split('.')[0]
logger.debug("Setting PACKAGE_NAME global variable to {}"
"".format(PACKAGE_NAME))
skipped_files = []
all_files = []
for parent, folders, files in os.walk(path_to_source_code):
for f in files:
if f.endswith('.py'):
full_file_path = os.path.join(parent, f)
all_files.append(full_file_path)
try:
yield parse_file(full_file_path, custom_namespaces=custom_namespaces)
except Exception:
logger.exception("Could not parse file: {}".format(full_file_path))
skipped_files.append(full_file_path)
if skipped_files:
logger.warning("Skipped {}/{} files".format(len(skipped_files), len(all_files)))
for idx, f in enumerate(skipped_files):
logger.warn("%s: %s" % (str(idx), f))
if skipped_files and STRICT_CHECKING:
raise RuntimeError("Some files failed to parse. See logs for full stack traces.")