python – Overloading a function with default arguments

Question:

I have a function that returns one type or the other depending on a bool argument, like this:

from typing import Union
def foo(num: int = 1, as_bytes: bool = False) -> Union[str, bytes]:
    rv = str(num)
    return rv.encode("utf-8") if as_bytes else rv

foo(1, False)  # Revealed type is 'Union[builtins.str, builtins.bytes]'
foo(as_bytes=True)  # Revealed type is 'Union[builtins.str, builtins.bytes]'

I want to write annotations for function overloading so that mypy understands which type is returned depending on the value of the argument:

foo(1, False)  # Revealed type is 'builtins.str'
foo(as_bytes=True)  # Revealed type is 'builtins.bytes'

The best thing I have been able to write is this:

from typing import Union, Literal, overload
@overload
def foo(num: int = 1, as_bytes: Literal[False] = False) -> str: ...
@overload
def foo(num: int = 1, as_bytes: Literal[True] = True) -> bytes: ...
@overload
def foo(num: int = 1, as_bytes: bool = False) -> Union[str, bytes]: ...

The problem is that mypy reports the error "Overloaded function signatures 1 and 2 overlap with incompatible return types" for this. That makes sense, because a call to foo() with no arguments matches both signatures. However, I still haven't figured out how to fix it. If I write as_bytes: Literal[True] = False instead, mypy reports "Incompatible default for argument", which also makes sense.
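To see why overloads 1 and 2 overlap, one can check at runtime that both parameter lists accept a bare call. The sketch below uses `inspect.signature(...).bind()` on two hypothetical stand-in functions (`sig1`, `sig2`) that only copy the parameter lists of those overloads; the helper `bare_call_arguments` is likewise made up for illustration.

```python
import inspect
from typing import Literal


# Hypothetical stand-ins that copy the parameter lists of overloads 1 and 2.
def sig1(num: int = 1, as_bytes: Literal[False] = False) -> str: ...
def sig2(num: int = 1, as_bytes: Literal[True] = True) -> bytes: ...


def bare_call_arguments(func):
    """Bind a call with no arguments; raises TypeError if func rejects it."""
    bound = inspect.signature(func).bind()
    bound.apply_defaults()  # fill in the default values
    return dict(bound.arguments)


# Both signatures accept foo() with no arguments, yet promise different
# return types (str vs bytes): that is the overlap mypy complains about.
for stub in (sig1, sig2):
    print(stub.__name__, bare_call_arguments(stub))
# sig1 {'num': 1, 'as_bytes': False}
# sig2 {'num': 1, 'as_bytes': True}
```

Because every parameter has a default in both signatures, no argument is needed to pick one over the other.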

How can these annotations be written correctly, if it is possible at all?

Answer:

This is the option I ended up with:

from typing import Union, Literal, overload


@overload
def foo(num: int = 1) -> str:  ...
@overload
def foo(*, as_bytes: Literal[True]) -> bytes:  ...
@overload
def foo(*, as_bytes: Literal[False]) -> str:  ...
@overload
def foo(num: int, as_bytes: Literal[True]) -> bytes: ...
@overload
def foo(num: int, as_bytes: Literal[False]) -> str: ...


def foo(num: int = 1, as_bytes: bool = False) -> Union[str, bytes]:
    rv = str(num)
    return rv.encode("utf-8") if as_bytes else rv



s = foo(1, False)       # inferred as str
b = foo(as_bytes=True)  # inferred as bytes
s2 = foo(1)             # inferred as str

With this, mypy does not seem to report any errors, and it infers the expected type for each of the calls above.
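One gap in the overloads above is that none of them accepts a plain `bool`, so a call such as `foo(1, flag)` with `flag: bool` is rejected with "No overload variant matches". The following is a more compact sketch, a variation rather than the answer's own code. It keeps a default only in the `Literal[False]` overload and adds a second `Literal[True]` overload in which `as_bytes` is a required keyword-only parameter, so neither `True` overload competes with a bare `foo()`. It ends with a `bool` fallback whose `Union` return type is a supertype of the others, which should avoid the overlap error on a reasonably recent mypy. The helper `render` is hypothetical.

```python
from typing import Literal, Union, overload


@overload
def foo(num: int = 1, as_bytes: Literal[False] = False) -> str: ...
@overload
def foo(num: int, as_bytes: Literal[True]) -> bytes: ...  # e.g. foo(1, True)
@overload
def foo(num: int = 1, *, as_bytes: Literal[True]) -> bytes: ...  # foo(as_bytes=True)
@overload
def foo(num: int = 1, as_bytes: bool = False) -> Union[str, bytes]: ...  # plain bool


def foo(num: int = 1, as_bytes: bool = False) -> Union[str, bytes]:
    rv = str(num)
    return rv.encode("utf-8") if as_bytes else rv


def render(n: int, flag: bool) -> Union[str, bytes]:
    # Hypothetical caller: flag is only known to be bool, so only the
    # fallback overload matches and the result is Union[str, bytes].
    return foo(n, flag)


s = foo()               # str
b = foo(as_bytes=True)  # bytes
b2 = foo(2, True)       # bytes
```

The keyword-only `*` in the third overload is what allows `as_bytes` to have no default while `num` keeps one.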
