Published: January 1, 2019
by Tobias Pleyer
Tags: python

Function Overloading in Python

Motivation

An overloaded function exhibits different behavior depending on the arguments it is called with. This is useful if you want to keep a consistent interface even though the variants of the function differ in arity or argument types. Consider the following example in pseudo code:

Since Python is dynamically typed, the same argument can hold values of different types at runtime. The above example is slightly nicer to read than the more explicit version:
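A minimal sketch of what such an explicit version might look like in Python. It reuses the lookupByString/lookupByIndex names and the customer data from the demo at the end of this post, but the list is shortened here (an assumption for brevity), so the index of the record differs.

```python
# Two separately named functions: the caller has to pick the right one.
# Data mirrors the demo at the end of the post, shortened for brevity.
customer_hash = {"John Doe": {"first": "John", "last": "Doe", "age": 42}}
customer_array = [{}, {}, {"first": "John", "last": "Doe", "age": 42}]


def lookupByString(s: str) -> dict:
    # Look a customer up by full name
    return customer_hash[s]


def lookupByIndex(i: int) -> dict:
    # Look a customer up by position in the list
    return customer_array[i]


# The call site encodes the argument type in the function name
print(lookupByString("John Doe"))
print(lookupByIndex(2))
```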

Python does not support function overloading out of the box, and there are certainly reasons not to use it at all. I won’t discuss the pros and cons of function overloading here; the goal of this blog post is simply to show that function overloading can be implemented in an elegant, maintainable and readable way.
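To see why overloading is not available out of the box: a def statement simply binds a name to a function object, so defining a second function with the same name replaces the first one. The function name lookup below is illustrative only.

```python
def lookup(s: str):
    return "by string: " + s


def lookup(i: int):  # rebinds the name 'lookup'; the str variant is gone
    return "by index: " + str(i)


print(lookup(3))       # by index: 3
# The annotations are not used for dispatch, so a str also ends up here
print(lookup("John"))  # by index: John
```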

Implementation

Hand-written

Obviously, at some point we have to decide which variant of the function to invoke, so the simplest possible implementation is a hand-written if-statement that checks the argument types.
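A sketch of such a hand-written dispatcher, assuming the two explicit variants lookupByString and lookupByIndex from the demo at the end of the post (data shortened). Note that isinstance also accepts subclasses, so for example lookup(True) would take the int branch.

```python
customer_hash = {"John Doe": {"first": "John", "last": "Doe", "age": 42}}
customer_array = [{}, {"first": "John", "last": "Doe", "age": 42}]


def lookupByString(s: str) -> dict:
    return customer_hash[s]


def lookupByIndex(i: int) -> dict:
    return customer_array[i]


def lookup(key):
    # Hand-written dispatch: every new variant needs a new branch here
    if isinstance(key, str):
        return lookupByString(key)
    elif isinstance(key, int):
        return lookupByIndex(key)
    raise TypeError(f"lookup() has no variant for {type(key).__name__}")


print(lookup("John Doe"))
print(lookup(1))
```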

But this implementation is very fragile! Every time a variant with a new combination of argument types is added, we have to remember to extend the if-statement.

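If the variants of the function are scattered throughout the source code, it is very easy to miss one. And what happens if another developer carelessly renames one of the functions? The if-statement still refers to the old name, and the error only shows up when that branch is actually executed.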
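A small sketch of the rename problem. The new name findByString is hypothetical: the variant was renamed, but the dispatcher still calls lookupByString. Defining the dispatcher succeeds, and the mistake only surfaces when the stale branch runs.

```python
def lookupByIndex(i: int) -> str:
    return f"customer #{i}"


# Hypothetical rename: lookupByString became findByString,
# but the hand-written dispatcher below was not updated.
def findByString(s: str) -> str:
    return f"customer {s}"


def lookup(key):
    if isinstance(key, str):
        return lookupByString(key)  # stale name, no longer defined anywhere
    elif isinstance(key, int):
        return lookupByIndex(key)
    raise TypeError(f"lookup() has no variant for {type(key).__name__}")


print(lookup(7))  # customer #7, the int branch still works
try:
    lookup("John Doe")
except NameError as err:
    # Only raised when the str branch is executed
    print("NameError:", err)
```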

Type Annotations and Decorators

Since Python 3.0 functions can carry annotations (PEP 3107), and since Python 3.5 the typing module (PEP 484) standardizes their use as optional type hints. The interpreter does not enforce these annotations, but third-party tools can use them for their analysis. If a function is annotated, the annotations are stored on the function object in its __annotations__ attribute:
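A short illustration, using a function exclaim that mirrors the two-argument variant from the demo below (the name and the return annotation are added here for illustration). The annotations end up in a plain dict keyed by parameter name, with the return annotation under the key 'return'. This assumes annotations are not turned into strings via from __future__ import annotations.

```python
def exclaim(n: int, x: str) -> str:
    return x + "!" * n


# Annotations live in a plain dict on the function object
print(exclaim.__annotations__)
# {'n': <class 'int'>, 'x': <class 'str'>, 'return': <class 'str'>}
```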

Since the annotations are part of the function object, they are also available to function decorators, which receive the function object as their argument. With this knowledge it is possible to hack function overloads via function decorators: each variant registers itself under a shared name, so once a function has been decorated no manual maintenance is required anymore. Below is my proof-of-concept implementation of simple function overloading.
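Before the full implementation, a minimal sketch of the core trick: a decorator reads the parameter annotations of the function it receives and uses them as a lookup key. The registry dict and the register decorator are hypothetical names for illustration only.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: maps a tuple of parameter types to a function
registry: Dict[Tuple[type, ...], Callable[..., Any]] = {}


def register(func: Callable[..., Any]) -> Callable[..., Any]:
    # The decorator receives the function object, so its annotations are available.
    # A filtered copy is built, leaving func.__annotations__ untouched.
    params = {name: tp for name, tp in func.__annotations__.items() if name != 'return'}
    registry[tuple(params.values())] = func
    return func  # the decorated function itself is returned unchanged


@register
def exclaim(n: int, x: str) -> str:
    return x + "!" * n


# The key is (int, str); the return annotation is not part of it
print(registry[(int, str)](3, "Hi"))  # Hi!!!
```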

from typing import Any, Callable, Tuple


class OverloadException(Exception):
    pass


class OverloadedFunction:
    def __init__(self, name):
        self._name = name
        self._overloads = {}

    def add_overload(self, signature: Tuple[Any, ...], func: Callable[..., Any]):
        # Each signature (tuple of parameter types) may be registered only once
        if signature in self._overloads:
            raise OverloadException(f"Overloaded function '{self._name}' has already been overloaded with signature {signature}")
        self._overloads[signature] = func

    def __call__(self, *args):
        # Dispatch on the exact runtime types of the positional arguments;
        # subclasses (e.g. bool for int) and keyword arguments are not matched
        signature = tuple(map(type, args))
        func = self._overloads.get(signature)
        if func is None:
            raise OverloadException(f"Overloaded function '{self._name}' does not provide an overload for the signature {signature}")
        return func(*args)


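def overload(overloaded_funcname: str):
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        # Bind the dispatcher in the module that defines func, not in the
        # module that defines overload() (globals() would refer to the latter)
        namespace = func.__globals__
        if overloaded_funcname not in namespace:
            namespace[overloaded_funcname] = OverloadedFunction(overloaded_funcname)
        overloaded_func = namespace[overloaded_funcname]
        if not isinstance(overloaded_func, OverloadedFunction):
            raise OverloadException(f"Name '{overloaded_funcname}' is already bound to something other than an overloaded function")
        # Work on a copy so that dropping 'return' does not mutate func.__annotations__;
        # note that unannotated parameters do not appear in __annotations__ at all
        annotations = dict(func.__annotations__)
        annotations.pop('return', None)
        signature = tuple(annotations.values())
        overloaded_func.add_overload(signature, func)
        return func
    return decorator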


if __name__ == '__main__':

    @overload('test')
    def test1(x: int):
        return x + 42

    @overload('test')
    def test2(x: float):
        return x * 0.5

    @overload('test')
    def test3(x: str):
        return x + "!!!"

    @overload('test')
    def test4(n: int, x: str):
        return x + "!"*n

    print(test(2))
    print(test(2.0))
    print(test("Hi"))
    print(test(5, "Hi"))

    customer_hash = { "John Doe": {"first": "John", "last": "Doe", "age": 42} }
    customer_array = 123456*[{}] + [{"first": "John", "last": "Doe", "age": 42}]

    @overload('lookup')
    def lookupByString(s: str):
        return customer_hash[s]

    @overload('lookup')
    def lookupByIndex(i: int):
        return customer_array[i]

    print(lookup("John Doe"))
    print(lookup(123456))
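One consequence of keying overloads on tuple(map(type, args)) is that only exact types match. The following standalone sketch (the overloads dict is a stand-in for OverloadedFunction._overloads) shows why a call like test(True) would raise OverloadException even though bool is a subclass of int.

```python
# Stand-in for the dispatcher's table: exact type tuples as keys
overloads = {(int,): lambda x: x + 42}

signature = tuple(map(type, (True,)))
print(signature)                 # (<class 'bool'>,)
print(overloads.get(signature))  # None: nothing is registered for bool
print(isinstance(True, int))     # True: an isinstance-based check would accept it
```

Supporting subclasses would require walking each argument's method resolution order instead of a single dict lookup, which this proof of concept deliberately leaves out.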