Метка: рекурсия

super() – супер класс в Python

super() – это встроенная функция языка Python. Она возвращает прокси-объект, который делегирует вызовы методов классу-родителю (или собрату) текущего класса (или класса на выбор, если он указан, как параметр).

Основное ее применение и польза – получения доступа из класса наследника к методам класса-родителя в том случае, если наследник переопределил эти методы.

Что такое прокси-объект? Прокси, по-русски, это заместитель. То есть это объект, который по смыслу должен вести себя почти так же, как замещенный объект. Как правило он перенаправляет вызовы своих методов к другому объекту.

Давайте рассмотрим пример наследования. Есть какой-то товар в классе Base с базовой ценой в 10 единиц. Нам понадобилось сделать распродажу и скинуть цену на 20%. Хардкодить – это непрофессионально и негибко:

class Base:
    def price(self):
        return 10

class Discount(Base):
    def price(self):
        return 8

Гораздо лучше было бы получить цену из родительского класса Base и умножить ее на коэффициент 0.8, что даст 20% скидку. Однако, если мы вызовем self.price() в методе price() мы создадим бесконечную рекурсию, так как это и есть один и тот же метод класса Discount! Тут же нужен метод Base.price(). Тогда его и вызовем по имени класса:

class Discount(Base):
    def price(self):
        return Base.price(self) * 0.8

Здесь, надо не забыть указать self при вызове первым параметром явно, чтобы метод был привязан к текущему объекту. Это будет работать, но этот код не лишен изъянов, потому что необходимо явно указывать имя предка. Представьте, если иерархия классов начнет разрастаться? Например, нам нужно будет вставить между этими классами еще один класс, тогда придется редактировать имя класса-родителя в методах Discount:

class Base:
    def price(self):
        return 10

class InterFoo(Base):
    def price(self):
        return Base.price(self) * 1.1

class Discount(InterFoo):  # <-- 
    def price(self):
        return InterFoo.price(self) * 0.8  # <-- 

Тут на помощь приходит super()! Супер он не потому что, подобно Супермэну, помогает всем людям, а потому что обращается к атрибутам классов стоящих над ним в порядке наследования (кто учил матан, вспомнят понятие супремум).

Будучи вызванным без параметров внутри какого-либо класса, super() вернет прокси-объект, методы которого будут искаться только в классах, стоящих ранее, чем он, в порядке MRO. То есть, это будет как будто бы тот же самый объект, но он будет игнорировать все определения из текущего класса, обращаясь только к родительским:

class Base:
    def price(self):
        return 10

class InterFoo(Base):
    def price(self):
        return super().price() * 1.1

class Discount(InterFoo):
    def price(self):
        return super().price() * 0.8
super calls

Для Discount порядок MRO: Discount - InterFoo - Base - object. Вызов super().method() внутри класса Discount будет игнорировать Discount.method(), а будет искать method в InterFoo, затем, если не найдет, то в Base и object.

Когда нельзя забыть super?

Очень часто super вызывается в методе __init__. Метод инициализации класса __init__, как правило задает какие-либо атрибуты экземпляра класса, и если в дочернем классе мы забудем его вызвать, то класс окажется недоинициализированным: при попытке доступа к родительским атрибутам будет ошибка:

class A:
    def __init__(self):
        self.x = 10

class B(A):
    def __init__(self):
        self.y = self.x + 5

# print(B().y)  # ошибка! AttributeError: 'B' object has no attribute 'x'

# правильно:

class B(A):
    def __init__(self):
        super().__init__()  # <- не забудь!
        self.y = self.x + 5

print(B().y)  # 15

Параметры super

Функция может принимать 2 параметра. super([type [, object]]). Первый аргумент – это тип, к предкам которого мы хотим обратиться. А второй аргумент – это объект, к которому надо привязаться. Сейчас оба аргумента необязательные. В прошлых версиях Python приходилось их указывать явно:

class A:
    def __init__(self, x):
        self.x = x

class B(A):
    def __init__(self, x):
        super(B, self).__init__(x)
        # теперь это тоже самое: super().__init__(x)

Теперь Python достаточно умен, чтобы самостоятельно подставить в аргументы текущий класс и self для привязки. Но старая форма тоже осталась для особых случаев. Она нужна, если вы используете super() вне класса или хотите явно указать с какого класса хотите начать поиск методов.

Действительно, super() может быть использована вне класса. Пример:

d = Discount()
print(super(Discount, d).price())

В этом случае объект, полученный из super(), будет вести себя как класс InterFoo (родитель Discount), хотя привязан он к переменной d, которая является экземпляром класса Discount.

Это редко используется, но, вероятно, кому-то будет интересно узнать, что функция super(cls), вызванная только с одним параметром, вернет непривязанный к экземпляру объект. У него нельзя вызывать методы и обращаться к атрибутам. Привязать его можно будет так:

super_d = super(Discount)
d = Discount()
binded_d = super_d.__get__(d, Discount)  # привязка
print(binded_d.price())  # 11.0

Множественное наследование

В случае множественного наследования super() необязательно указывает на родителя текущего класса, а может указывать и на собрата. Все зависит от структуры наследования и начальной точки вызова метода. Общий принцип остается: поиск начинается с предыдущего класса в списке MRO. Давайте рассмотрим пример ромбовидного наследования. Каждый класс ниже в методе method печатает свое имя. Плюс все, кроме первого, вызывают свой super().method():

class O:
    def method(self):
        print('I am O')

class A(O):
    def method(self):
        super().method()
        print('I am A')

class B(O):
    def method(self):
        super().method()
        print('I am B')


class C(A, B):
    def method(self):
        super().method()
        print('I am C')

Если вызвать метод C().method(), то в терминале появится такая распечатка:

# C().method()
I am O
I am B
I am A
I am C

Видно, что каждый метод вызывается ровно один раз и ровно в порядке MRO. C вызывает родителя A, а A вызывает своего брата B, а B вызывает их общего родителя O. Но! Стоит нам вызвать A().method(), он уже не будет вызывать B().method(), так как класса B нет среди его родителей, он брат, а родитель у класс А только один – это O. А о братьях он и знать не хочет:

# A().method()
I am O
I am A

Таким образом, вызов super() сам автоматически догадывается, к кому обращаться: к родителю или к брату. Все зависит от иерархии класса и начальной точки вызова. Эта фишки снимают с программиста головную боль, связанную с заботой о поддержании цепочки вызовов в иерархии классов.

Специально для канала @pyway. Подписывайтесь на мой канал в Телеграм @pyway 👈 

Расстояние Левенштейна на Python

Как понять насколько близки две строки? Как поисковая система все равно находит то, что надо, даже если вы совершили пару опечаток в запросе? В этом вопросе нам поможет расстояние по Левенштейну или редакционное расстояние. Почему редакционное? Потому что мы считаем, сколько действий по редактированию одной строки нужно совершить, чтобы получить другую строку. Действия бывают следующими: вставка символа, удаление символа и замена одного символа другим.

Одинаковые строки имеют нулевое расстояние: не нужно ничего редактировать, первая и так равна второй. «Привет» и «Привт» имеют расстояние 1 (пропущена одна буква, остальное не изменилось). «Привет» и «привты» имеют расстояние 3 (одна замена «П» на «п», удаление буквы «е» и вставка «ы» на конце). Мы будем считать

Я не буду сюда копировать теорию и тем более доказательства, это вы можете изучить в Вики.

Давайте попробуем реализовать этот алгоритм в лоб по формуле:

Рекурсивная формула

Функция m – возвращает единицу, если символы не равны, иначе 0. Вот такой код получится:

def my_dist(a, b):
    def recursive(i, j):
        if i == 0 or j == 0:
            # если одна из строк пустая, то расстояние до другой строки - ее длина
            # т.е. n вставок
            return max(i, j)
        elif a[i - 1] == b[j - 1]:
            # если оба последних символов одинаковые, то съедаем их оба, не меняя расстояние
            return recursive(i - 1, j - 1)
        else:
            # иначе выбираем минимальный вариант из трех
            return 1 + min(
                recursive(i, j - 1),  # удаление
                recursive(i - 1, j),   # вставка
                recursive(i - 1, j - 1)  # замена
            )
    return recursive(len(a), len(b))

Вспомогательная функция, чтобы протестировать производительность:

from timeit import timeit

def test_lev_dist(f: callable, a, b, n=1):
    tm = timeit("f(a, b)", globals={
        'f': f, 'a': a, 'b': b
    }, number=n)
    r = f(a, b)
    print(f'a = {a!r} and b = {b!r} and {f.__name__} = {r} and time = {tm:.4f}')

test_lev_dist(my_dist, "hello world", "bye world!")
# a = 'hello world' and b = 'bye world!' and my_dist = 6 and time = 1.3829

Как можете видеть, первый вариант кода работает экстремально медленно, потому что много раз вычисляет функцию при одних и тех же входных параметрах. Давайте узнаем, сколько раз вызывается внутренняя функция. Для этого добавим декоратор, который считает число вызовов:

def count_calls(f):
    @wraps(f)
    def wrapped(*args, **kwargs):
        wrapped.n_calls += 1
        return f(*args, **kwargs)
    wrapped.n_calls = 0
    return wrapped

def my_dist(a, b):
    @count_calls
    def recursive(i, j):
        ...  # прочий код пропущен
    r = recursive(len(a), len(b))
    print('calls =', recursive.n_calls)
    return r

my_dist("hello world", "bye world!")
# calls =  3317804

Более 3 миллионов вызовов! И большинство из них с одинаковыми параметрами. А так как они повторяются, то можно их закешировать (при помощи lru_cache). Здесь размер кэша примерно равен числу различных комбинаций входных параметров.

from functools import lru_cache

def my_dist_cached(a, b):
    @count_calls
    @lru_cache(maxsize=len(a) * len(b))
    def recursive(i, j):
        if i == 0 or j == 0:
            return max(i, j)
        elif a[i - 1] == b[j - 1]:
            return recursive(i - 1, j - 1)
        else:
            return 1 + min(
                recursive(i, j - 1), 
                recursive(i - 1, j), 
                recursive(i - 1, j - 1)
            )

    r = recursive(len(a), len(b))
    print('calls = ', recursive.n_calls)
    return r

my_dist_cached("hello world", "bye world!")
# calls = 212

Количество вызовов сократилось до 212, а скорость работы увеличилась на порядки. Выкинем count_calls и проведем замеры времени для 10000 повторных вызовов.

def my_dist_cached(a, b):
    @lru_cache(maxsize=len(a) * len(b))
    def recursive(i, j):
        if i == 0 or j == 0:
            return max(i, j)
        elif a[i - 1] == b[j - 1]:
            return recursive(i - 1, j - 1)
        else:
            return 1 + min(
                recursive(i, j - 1),
                recursive(i - 1, j),
                recursive(i - 1, j - 1)
            )
    return recursive(len(a), len(b))

test_lev_dist(my_dist_cached, "hello world", "bye world!", n=10000)
# a = 'hello world' and b = 'bye world!' and my_dist_cached = 6 and time = 0.9740

Производительность улучшилась радикально (в прошлый раз мы прогоняли один вызов функции, а теперь 10000 раз, и то выходит быстрее). Однако, пока что объем требуемой памяти растет как O(len(a) * len(b)). Иными словами, для двух мегабайтных строк потребуются гигабайты памяти. Фактически в кэше будет хранится почти все матрица редактирований, а она нам не нужна целиком. Наша цель – правый нижний элемент.

Матрица редактирований recursive(i, j)
Матрица редактирований recursive(i, j)

Для его поиска можно обойтись лишь парой рядов: текущим и предыдущим. А остальные ряды не хранить в памяти. Так мы дойдем до конца таблицы, и нижний правый угол и будет искомым значением.

Вот пример реализации:

def distance(a, b):
    n, m = len(a), len(b)
    if n > m:
        # убедимся что n <= m, чтобы использовать минимум памяти O(min(n, m))
        a, b = b, a
        n, m = m, n

    current_row = range(n + 1)  # 0 ряд - просто восходящая последовательность (одни вставки)
    for i in range(1, m + 1):
        previous_row, current_row = current_row, [i] + [0] * n
        for j in range(1, n + 1):
            add, delete, change = previous_row[j] + 1, current_row[j - 1] + 1, previous_row[j - 1]
            if a[j - 1] != b[i - 1]:
                change += 1
            current_row[j] = min(add, delete, change)

    return current_row[n]

Объяснение. Сначала, чтобы использовать еще меньше памяти, мы можем поменять местами наши строки, чтобы длина рядов была минимальна. Это существенно экономит память, если одна из строк длинная, а другая короткая.

Затем мы понимаем, что нулевой ряд или столбец матрицы – просто восходящая последовательность. Потому что, чтобы из пустой строки получить любую не пустую, нужно ровно то число вставок, сколько символов в не пустой строке. И наоборот: n удалений из строки длины n приведут неизбежно к пустой строке.

Нам достаточно пары рядов.
Тут на картинке не ряды, а столбцы, но смысла это не меняет.

Потом мы шагаем по рядам матрицы, помня только текущий ряд и предыдущий, мы заполняем неизвестные клетки текущего ряда. Соседние клетки отвечают за вставку одного символа, удаление или замену (если символы неравны). Из трех возможных изменений мы выбираем то, чья стоимость наименьшая.

Эта версия еще быстрее, чем кэшированная:

 test_lev_dist(distance, "hello world", "bye world!", n=10000)
# a = 'hello world' and b = 'bye world!' and distance = 6 and time = 0.7374

Сложность этого алгоритма растет как произведение длин строк: O(n*m). Это еще не самый эффективный алгоритм. Для дальнейшего ускорения нужно воспользоваться древовидной структурой данных. Также неплохо бы учесть то, что на известном словаре можно заранее вычислить расстояния между словами.

Наконец-то, когда мы разобрались с принципом работы алгоритма, вспомним, что все велосипеды уже написаны до нас, да еще и на языке Си. Воспользуемся библиотечными функциями, установив пакет:

 pip install python-Levenshtein
import Levenshtein

test_lev_dist(Levenshtein.distance, "hello world", "bye world!", n=10000)
# a = 'hello world' and b = 'bye world!' and distance = 6 and time = 0.0032

Специально для канала @pyway. Подписывайтесь на мой канал в Телеграм @pyway 👈