aboutsummaryrefslogtreecommitdiff
blob: 0648de1ef7c3e5217efd4c7c31e4f185303d2e9a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import errno
import inspect
import os
import stat
import sys

from ..compatibility import IGNORED_EXCEPTIONS


class PythonNamespaceWalker:
    """Discover and import the modules that live below a Python package.

    Behaviour is tuned through the class attributes below and by overriding
    :meth:`_default_module_blacklister`.
    """

    # Default for get_modules(ignore_failed_imports=None): when true,
    # ImportErrors from discovered modules are skipped instead of raised.
    ignore_all_import_failures = False

    # File names whose presence marks a directory as a regular package.
    valid_inits = frozenset(f"__init__.{x}" for x in ("py", "pyc", "pyo", "so"))

    # This is for py3.2/PEP3149; dso's now have the interp + major/minor embedded
    # in the name.
    # TODO: update this for pypy's naming
    abi_target = "cpython-%i%i" % tuple(sys.version_info[:2])

    # Fully dotted module names that the default blacklister refuses to import.
    module_blacklist = frozenset(
        {
            "snakeoil.cli.arghparse",
            "snakeoil.pickling",
        }
    )

    def _default_module_blacklister(self, target):
        """Return true if the dotted module name ``target`` must not be imported."""
        return target in self.module_blacklist or target.startswith("snakeoil.dist")

    def walk_namespace(self, namespace, **kwds):
        """Import ``namespace`` and lazily yield it plus the modules beneath it.

        The package directory is located via the imported package's
        ``__file__``; remaining keyword arguments are forwarded to
        :meth:`get_modules`.
        """
        location = os.path.abspath(
            os.path.dirname(self.poor_mans_load(namespace).__file__)
        )
        return self.get_modules(self.recurse(location), namespace=namespace, **kwds)

    def get_modules(
        self, feed, namespace=None, blacklist_func=None, ignore_failed_imports=None
    ):
        """Import and yield the modules named by ``feed``.

        :param feed: iterable of module names relative to ``namespace`` (as
            produced by :meth:`recurse`).  A ``None`` entry stands for
            ``namespace`` itself and is skipped when ``namespace`` is None.
        :param namespace: dotted prefix joined in front of every name.
        :param blacklist_func: predicate on the full dotted name; names for
            which it returns true are not imported.  Defaults to
            :meth:`_default_module_blacklister`.
        :param ignore_failed_imports: skip modules whose import raises
            ImportError instead of propagating it.  Defaults to
            :attr:`ignore_all_import_failures`.
        """
        if ignore_failed_imports is None:
            ignore_failed_imports = self.ignore_all_import_failures
        if blacklist_func is None:
            blacklist_func = self._default_module_blacklister
        for mod_name in feed:
            # Build the target in a fresh local rather than rebinding
            # ``namespace``, so every entry is resolved against the same prefix.
            if mod_name is None:
                if namespace is None:
                    continue
                target = namespace
            elif namespace is None:
                target = mod_name
            else:
                target = f"{namespace}.{mod_name}"
            if blacklist_func(target):
                continue
            try:
                module = self.poor_mans_load(target)
            except ImportError:
                if not ignore_failed_imports:
                    raise
                continue
            yield module

    def recurse(self, location, valid_namespace=True):
        """Yield dotted module names found beneath the directory ``location``.

        Names are relative to ``location``.  ``None`` is yielded first when
        ``location`` contains one of :attr:`valid_inits` (it is a package);
        otherwise, if ``valid_namespace`` is true, nothing is yielded at all.
        Sub-directories are searched recursively and only contribute names
        when they are packages themselves.

        :param location: directory to scan.
        :param valid_namespace: require ``location`` itself to be a package.
        """
        if os.path.basename(location) == "__pycache__":
            # Bytecode caches are never packages.  Compare the final path
            # component: dirname() returns the *parent* path and never matches.
            return
        entries = os.listdir(location)
        if not self.valid_inits.intersection(entries):
            if valid_namespace:
                return
        else:
            yield None

        stats: list[tuple[str, int]] = []
        for name in entries:
            path = os.path.join(location, name)
            try:
                stats.append((name, os.stat(path).st_mode))
            except OSError as exc:
                if exc.errno != errno.ENOENT:
                    raise
                # file disappeared under our feet... lock file from
                # trial can cause this.  ignore.
                import logging

                logging.debug("file %r disappeared under our feet, ignoring", path)

        seen = {"__init__"}
        for name, mode in stats:
            if name.startswith(".") or name.endswith("~") or not stat.S_ISREG(mode):
                continue
            if not name.endswith((".py", ".pyc", ".pyo", ".so")):
                continue
            stem, ext = name.rsplit(".", 1)
            if ext == "so" and "." in stem:
                # PEP 3149 extensions are named "<module>.<abi-tag>.so" and the
                # tag may carry a platform suffix, e.g.
                # "cpython-311-x86_64-linux-gnu".  Strip the tag from the
                # module name and skip builds for other interpreters.
                stem, abi = stem.rsplit(".", 1)
                if abi != self.abi_target and not abi.startswith(
                    f"{self.abi_target}-"
                ):
                    continue
            # De-duplicate after the tag is stripped so that e.g. foo.py,
            # foo.pyc and foo.<tag>.so produce "foo" only once.
            if stem in seen:
                continue
            seen.add(stem)
            yield stem

        for name, mode in stats:
            if stat.S_ISDIR(mode):
                for sub in self.recurse(os.path.join(location, name)):
                    yield name if sub is None else f"{name}.{sub}"

    @staticmethod
    def poor_mans_load(namespace, existence_check=False):
        """Import ``namespace`` and return the (sub)module object it names.

        ``__import__("a.b")`` returns the top-level package ``a``, so the
        remaining dotted components are resolved with getattr.

        :param existence_check: if true, return True/False for whether the
            import succeeded instead of returning the module.
        :raises AssertionError: if a dotted component cannot be resolved after
            the import itself succeeded.
        """
        try:
            obj = __import__(namespace)
        except Exception:
            # Only ordinary errors mean "cannot be imported"; an existence
            # probe must not swallow KeyboardInterrupt or SystemExit.
            if existence_check:
                return False
            raise
        if existence_check:
            return True
        for chunk in namespace.split(".")[1:]:
            try:
                obj = getattr(obj, chunk)
            except IGNORED_EXCEPTIONS:
                # Exceptions the compatibility layer designates as "never
                # wrap" propagate unchanged.
                raise
            except AttributeError:
                raise AssertionError(f"failed importing target {namespace}")
            except Exception as e:
                raise AssertionError(f"failed importing target {namespace}; error {e}")
        return obj


class TargetedNamespaceWalker(PythonNamespaceWalker):
    """Namespace walker bound to one package named by ``target_namespace``."""

    # Dotted package name walked by load_namespaces() when no namespace is
    # passed explicitly; None here, presumably set by subclasses.
    target_namespace = None

    def load_namespaces(self, namespace=None):
        """Import every module under ``namespace`` (default: ``target_namespace``).

        The modules are imported purely for that side effect; nothing is
        returned.
        """
        target = self.target_namespace if namespace is None else namespace
        # walk_namespace() is lazy: imports happen only as it is advanced, so
        # drain it completely.
        modules = self.walk_namespace(target)
        for _ in modules:
            continue


class _classWalker:
    """Base for mixins that run :meth:`run_check` over many classes.

    Concrete subclasses implement :meth:`walk_derivatives` and
    :meth:`run_check`.  The ``test_*`` entry points additionally call
    ``self.load_namespaces()`` and ``self._should_ignore(cls)``; neither is
    defined on this class.  NOTE(review): presumably they are supplied by a
    co-base such as TargetedNamespaceWalker and by the concrete test class --
    verify against users.
    """

    # Class ``__name__`` values that is_blacklisted() reports as excluded.
    cls_blacklist = frozenset()

    def is_blacklisted(self, cls):
        """Return true if ``cls.__name__`` is listed in :attr:`cls_blacklist`."""
        return cls.__name__ in self.cls_blacklist

    def test_object_derivatives(self, *args, **kwds):
        """Load namespaces, then check every class walked from ``object``."""
        # first load all namespaces...
        self.load_namespaces()

        # next walk all derivatives of object
        for cls in self.walk_derivatives(object, *args, **kwds):
            if not self._should_ignore(cls):
                self.run_check(cls)

    def iter_builtin_targets(self):
        """Yield every class exposed by the :mod:`builtins` module."""
        # Use the builtins module, not ``__builtins__``: outside __main__
        # CPython binds ``__builtins__`` to a plain dict, and dir() of a dict
        # lists dict attributes rather than the builtin names.
        import builtins

        for attr in dir(builtins):
            obj = getattr(builtins, attr)
            if inspect.isclass(obj):
                yield obj

    def test_builtin_derivatives(self, *args, **kwds):
        """Load namespaces, then check classes walked from each builtin class."""
        self.load_namespaces()
        for obj in self.iter_builtin_targets():
            for cls in self.walk_derivatives(obj, *args, **kwds):
                if not self._should_ignore(cls):
                    self.run_check(cls)

    def walk_derivatives(self, obj):
        """Yield the classes derived from ``obj`` to check; subclasses implement."""
        raise NotImplementedError(self.__class__, "walk_derivatives")

    def run_check(self, cls):
        """Perform the actual check on ``cls``; subclasses implement."""
        raise NotImplementedError


class SubclassWalker(_classWalker):
    """Walk a class hierarchy, yielding only its leaf classes.

    A class is yielded when it has no subclass worth descending into.
    Blacklisted subclasses (see is_blacklisted) and subclasses whose
    ``__subclasses__`` cannot be called without arguments (``type`` and other
    metaclasses) do not count as children.
    """

    @staticmethod
    def _walkable(klass):
        # type.__subclasses__ (and that of any metaclass) is reached unbound and
        # needs an argument, so such classes cannot be walked.
        return not inspect.signature(klass.__subclasses__).parameters

    def walk_derivatives(self, cls, seen=None):
        """Yield the leaf classes of the hierarchy rooted at ``cls``.

        :param cls: root class to walk from.
        :param seen: set of classes already visited; shared across the
            recursion so each class is descended into at most once.
        """
        if not self._walkable(cls):
            return
        if seen is None:
            seen = set()
        # Previously ``pos == 0`` after enumerate(), which is also true when
        # there is exactly one subclass, so such non-leaf classes leaked out.
        is_leaf = True
        for subcls in cls.__subclasses__():
            if self.is_blacklisted(subcls) or not self._walkable(subcls):
                continue
            # A subclass already reached through another base still has its
            # leaves walked, so it still makes ``cls`` a non-leaf.
            is_leaf = False
            if subcls in seen:
                continue
            seen.add(subcls)
            yield from self.walk_derivatives(subcls, seen)
        if is_leaf:
            yield cls


class KlassWalker(_classWalker):
    """Walk a class hierarchy, yielding every class derived from the root.

    Classes come out in depth-first pre-order, each at most once.  The root
    itself is yielded only when the caller passes a ``seen`` set that does
    not already contain it.  Classes whose ``__subclasses__`` needs an
    argument (e.g. ``type`` and other metaclasses) are skipped with their
    whole branch.

    NOTE(review): unlike SubclassWalker this never consults
    is_blacklisted()/cls_blacklist -- confirm that is intentional.
    """

    def walk_derivatives(self, cls, seen=None):
        def needs_args(klass):
            return bool(inspect.signature(klass.__subclasses__).parameters)

        if needs_args(cls):
            return

        if seen is None:
            seen = set()
        elif cls not in seen:
            seen.add(cls)
            yield cls

        # Explicit stack of subclass iterators replaces recursion; the
        # innermost (last) iterator is always the one being advanced.
        pending = [iter(cls.__subclasses__())]
        while pending:
            for child in pending[-1]:
                if child in seen or needs_args(child):
                    continue
                seen.add(child)
                yield child
                pending.append(iter(child.__subclasses__()))
                break
            else:
                # Current level exhausted: resume the parent's iterator.
                pending.pop()