Do-notation for Python

Recap

In a previous post I introduced monadic computations for Python. In a follow-up post I extended the example of the original post. Make sure you read both posts before continuing with this one.

In the closing words of the original post I admitted that the given example code looked a tiny bit ugly and unreadable, and I blamed the lack of expressiveness on Python’s missing support for do-notation. The very last coding example of the aforementioned blog post showed the example written in Haskell. The fact that Haskell’s do-notation spares the programmer from excessive use of parentheses and other syntactic clutter is a real win.

Do-notation in Haskell is actually just syntactic sugar. A Haskell compiler will rewrite computations given in do-notation using lambda functions and the monad’s bind (>>=) function.

So you can ask: if Haskell’s do-notation is just a facade, so to speak, what should stop us from doing the same in Python? Nothing, as it turns out. Of course we don’t want to modify Python’s compiler; that would be quite a lot of work. But we can write a function, do, that will do the trick (pun intended).

Do

Understanding do-notation

Before we can start writing our do function we have to understand how do-notation actually works.

In total we have to deal with 4 cases:

  • A monadic computation with the result being ignored (sequence)
  • A monadic computation with the result not being ignored (bind)
  • A let-expression
  • A nested do-notation

Here is an example that makes use of all these cases:

#!/usr/bin/env stack
{- stack script
   --resolver lts-11.8
   --package base
-}

-- File do_demo.hs

main = do
  putStr "How many gos do you want? "
  goCnt <- do
    line <- getLine
    let cnt = read line :: Int
    return cnt
  sequence_ $ replicate goCnt (putStr "go")
  putStrLn "!"

$ chmod +x do_demo.hs
$ ./do_demo.hs
How many gos do you want? 3
gogogo!

And here is the de-sugared (rewritten) version:

#!/usr/bin/env stack
{- stack script
   --resolver lts-11.8
   --package base
-}

-- File do_demo_desugared.hs

main =
  putStr "How many gos do you want? " >>
    (
    getLine >>= \line ->
    return (read line :: Int) >>= \cnt ->
    return cnt
    ) >>= \goCnt ->
  (sequence_ $ replicate goCnt (putStr "go")) >>
  putStrLn "!"

A lot harder to read, right? That’s the reason do-notation was introduced: to make monadic code nicer to write and read.

As the simple example shows, however, rewriting do-notation is not hard; it is entirely mechanical, which means computers are really good at doing it. Equipped with this knowledge we can start writing our do-function.
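
To see how mechanical the rewrite is, here is a minimal Python sketch that desugars a flat list of do-lines into a Haskell-style expression string. The tuple encoding and the name desugar are assumptions made for this illustration only; it is not the implementation developed later in this post.

```python
from typing import List, Optional, Tuple

# Hypothetical encoding of a do-line: (variable or None, expression).
# None means the result of the expression is ignored (sequence).
Line = Tuple[Optional[str], str]


def desugar(lines: List[Line]) -> str:
    """Rewrite a flat do-block into nested >>= / >> applications."""
    if not lines:
        raise ValueError("empty do-block")
    (var, expr), rest = lines[0], lines[1:]
    if not rest:
        # The last line must be a plain expression, never a bind
        if var is not None:
            raise ValueError("the last line of a do must be an expression")
        return expr
    if var is None:
        return f"{expr} >> ({desugar(rest)})"
    return f"{expr} >>= \\{var} -> ({desugar(rest)})"


result = desugar([
    (None, 'putStr "Name? "'),
    ("line", "getLine"),
    (None, "putStrLn line"),
])
print(result)
```

Each line only decides which operator to emit and whether a lambda is needed; everything after it becomes the nested right-hand side.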

Test Driven Development

But before we start coding away we want to follow good coding practices. One of them is to set up a test suite before any code is written, and then write the code against the constraints of the given tests. In this way we get live feedback about the quality and correctness of our implementation. This technique is called test-driven development, or TDD for short.

Here is our Python script, with an empty do-function and the test case to test it:

import unittest


def do(s: str) -> str:
    """The do-function.
    Takes a string containing the computation written in Haskell style
    do-notation and returns a string with the computation rewritten using the
    bind (>>=) and sequence (>>) functions of the monad type class.

    This function will render the whole computation in one line without
    indentation.

    The returned string should be valid to be passed to Python's `eval`.
    """
    pass


class DoTest(unittest.TestCase):

    def test_empty(self):
        self.assertEqual(do(""), "")

    def test_oneline_seq_is_ok(self):
        test_input = 'putStrLn "Test"'
        expected = '(putStrLn "Test")'
        self.assertEqual(do(test_input), expected)

    def test_oneline_bind_fails(self):
        test_input = 'line <- getLine'
        with self.assertRaises(ValueError):
            do(test_input)

    def test_oneline_let_fails(self):
        test_input = 'let x = 42'
        with self.assertRaises(ValueError):
            do(test_input)

    def test_sequence_only(self):
        test_input = """
        putStrLn "Test"
        putStrLn "Test2"
        stackReturn ()
        """
        expected = (
                '(putStrLn "Test")' + ' >> (' +
                '(putStrLn "Test2")' + ' >> (' +
                '(stackReturn ())' +
                '))'
                )
        self.assertEqual(do(test_input), expected)

    def test_one_bind(self):
        test_input = """
        line <- getLine
        putStrLn line
        """
        expected = (
                '(getLine)' + ' | (lambda line: ' +
                '(putStrLn line)' +
                ')'
                )
        self.assertEqual(do(test_input), expected)

    def test_two_binds(self):
        test_input = """
        line <- getLine
        line2 <- getLine
        putStrLn line
        putStrLn line2
        """
        expected = (
                '(getLine)' + ' | (lambda line: ' +
                '(getLine)' + ' | (lambda line2: ' +
                '(putStrLn line)' + ' >> (' +
                '(putStrLn line2)' +
                ')))'
                )
        self.assertEqual(do(test_input), expected)

    def test_simple_let(self):
        test_input = """
        let x = 42
        putStrLn (show x)
        """
        expected = (
                '(stackReturn (42))' + ' | (lambda x: ' +
                '(putStrLn (show x))' +
                ')'
                )
        self.assertEqual(do(test_input), expected)

    def test_ignore_empty_line(self):
        test_input = """
        putStrLn "foo"

        putStrLn "bar"
        """
        expected = (
                '(putStrLn "foo")' + ' >> (' +
                '(putStrLn "bar")' +
                ')'
                )
        self.assertEqual(do(test_input), expected)

    def test_ignore_comments(self):
        test_input = """
        # first comment
        res <- complicatedComputation(arg1, arg2)
        # second comment
        shell("more stuff", res)
        """
        expected = (
                '(complicatedComputation(arg1, arg2))' + ' | (lambda res: ' +
                '(shell("more stuff", res))' +
                ')'
                )
        self.assertEqual(do(test_input), expected)

    def test_example(self):
        test_input = """
        putStr "How many gos do you want? "
        line <- getLine
        let cnt = read line :: Int
        sequence_ $ replicate cnt (putStr "go")
        putStrLn "!"
        """
        expected = (
                '(putStr "How many gos do you want? ")' + ' >> (' +
                '(getLine)' + ' | (lambda line: ' +
                '(stackReturn (read line :: Int))' + ' | (lambda cnt: ' +
                '(sequence_ $ replicate cnt (putStr "go"))' + ' >> (' +
                '(putStrLn "!")' +
                '))))'
                )
        self.assertEqual(do(test_input), expected)


if __name__ == '__main__':
    unittest.main()

As can be seen from the expectations, the output of do will be overly parenthesized. This doesn’t hurt the reader/programmer, because they will never see that code, but makes deterministic rewriting much easier, because we don’t have to think about precedence etc.
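
As a sanity check that such a fully parenthesized string really is a valid Python expression, here is a small sketch with a toy wrapper class. Box and unit are hypothetical stand-ins invented for this example; the monad from the earlier posts is only assumed to overload | as bind and >> as sequence in a similar way.

```python
class Box:
    """Toy identity-like monad: wraps a value, nothing more."""

    def __init__(self, value):
        self.value = value

    def __or__(self, fn):
        # bind: feed the wrapped value to a function returning a Box
        return fn(self.value)

    def __rshift__(self, other):
        # sequence: ignore our own value, continue with the next Box
        return other


def unit(x):
    # Hypothetical counterpart of stackReturn: wrap a plain value
    return Box(x)


# Shaped like the output of do: every step is wrapped in parentheses.
code = "(unit (20)) | (lambda x: (unit (1)) >> ((unit (x + 22))))"
answer = eval(code).value
print(answer)
```

The body of a Python lambda extends as far to the right as possible, which is why all remaining steps end up inside it and only the closing parentheses at the very end are needed.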

Writing the do-function

If you look at the tests, the expected values are written very verbosely. There are three reasons for this:

  • It is easy to connect it visually to the input
  • It is easier to modify the test quickly
  • It already hints at the solution strategy

This last point is of special interest to us, because we want to implement the function now. To make things easy we assume the following:

  • The do keyword has to be the last word of the line (nothing follows)

We will follow this strategy, for now ignoring nested do:

  1. Define a class for a do-line
  2. Create rewrite rules for the do-line class
  3. Write a parser for a let expression
  4. Write a parser for a monadic computation with ignored result
  5. Write a parser for a monadic computation with the result bound to a variable
  6. Write a parser that can parse a line, no matter if let, sequence or bind
  7. Our do-function then has to
    • Split the input into lines
    • Filter out comments and empty lines
    • Parse the lines to do-lines
    • Make some assertions (e.g. no binds at the end)
    • Desugar (rewrite) the do-lines
    • Merge all strings to a final string
    • Balance the open parentheses with closing ones

This should give us a string that can be evaluated with Python’s eval function. The rest is then handled by the explicit monadic code.

Our parsing framework will be pyparsing. Pyparsing provides a nice API to build up parsers. I already compared pyparsing with attoparsec in this post.

The do-line class

The following class will be used to represent a basic do-line and check off points 1 and 2.

from enum import IntEnum
from typing import List

class DoLine:
    class Type(IntEnum):
        LET = 1
        SEQ = 2
        BIND = 3

    @classmethod
    def type2Str(cls, typ):
        s = ""
        if typ == cls.Type.LET:
            s = "LET"
        elif typ == cls.Type.SEQ:
            s = "SEQ"
        elif typ == cls.Type.BIND:
            s = "BIND"
        return s

    def __init__(self, typ, val: str, var:str =""):
        self.typ = typ
        self.val = val.strip()
        self.var = var.strip()

    def __str__(self):
        return (self.type2Str(self.typ) + " " + self.var + ": " + self.val)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Compare field by field against the other DoLine
        return ((self.typ == other.typ) and
                (self.val == other.val) and
                (self.var == other.var))

    def rewrite(self):
        val = f"({self.val})"
        typ = self.typ
        var = self.var
        if typ == self.Type.LET:
            return f"(stackReturn {val}) | (lambda {var}: "
        elif typ == self.Type.SEQ:
            return f"{val} >> ("
        elif typ == self.Type.BIND:
            return f"{val} | (lambda {var}: "

As you can see, the rewrite rules are very simple. Once we know the type of the do-line and its components, all we have to do is wrap the components in parentheses and use the correct monadic function ((>>) or (>>=), written >> and | in the generated Python code).

The do-line parsers

We will now work on points 3-6 of the above checklist. Here is our empty test skeleton:

from enum import IntEnum
from functools import partial, reduce
from typing import List
import unittest
import pyparsing as P


def literal(s): return P.Literal(s)
def literal_(s): return P.Suppress(P.Literal(s))
def word(s): return P.Word(s)
def word_(s): return P.Suppress(P.Word(s))


# These have to be implemented
parseLet = None
parseBind = None
parseSeq = None
parseDoLine = None


class ParserTest(unittest.TestCase):

    def test_parseLet(self):
        self.assertEqual(
                parseLet.parseString("let x = 42")[0],
                DoLine(DoLine.Type.LET, "42", "x")
                )

    def test_parseDoLine_with_let(self):
        self.assertEqual(
                parseDoLine.parseString("let x = 42")[0],
                DoLine(DoLine.Type.LET, "42", "x")
                )

    def test_parseSeq(self):
        self.assertEqual(
                parseSeq.parseString("func(arg1, arg2)")[0],
                DoLine(DoLine.Type.SEQ, "func(arg1, arg2)")
                )

    def test_parseDoLine_with_seq(self):
        self.assertEqual(
                parseDoLine.parseString("func(arg1, arg2)")[0],
                DoLine(DoLine.Type.SEQ, "func(arg1, arg2)")
                )

    def test_parseBind(self):
        self.assertEqual(
                parseBind.parseString("res <- func(arg1, arg2)")[0],
                DoLine(DoLine.Type.BIND, "func(arg1, arg2)", "res")
                )

    def test_parseDoLine_with_bind(self):
        self.assertEqual(
                parseDoLine.parseString("res <- func(arg1, arg2)")[0],
                DoLine(DoLine.Type.BIND, "func(arg1, arg2)", "res")
                )


if __name__ == '__main__':
    unittest.main()

Since the parsers are line-based, they are rather straightforward:

from functools import partial
import pyparsing as P


def literal(s): return P.Literal(s)
def literal_(s): return P.Suppress(P.Literal(s))
def word(s): return P.Word(s)
def word_(s): return P.Suppress(P.Word(s))

def createDoLine(typ, parseResult):
    context = parseResult.asDict()
    return DoLine(typ, **context)

createLet = partial(createDoLine, DoLine.Type.LET)
createSeq = partial(createDoLine, DoLine.Type.SEQ)
createBind = partial(createDoLine, DoLine.Type.BIND)


parseLet = ( literal_("let")
           + word(P.alphanums).setResultsName("var")
           + literal_("=")
           + P.restOfLine.setResultsName("val")
           ).setParseAction(createLet)


parseBind = ( word(P.alphanums).setResultsName("var")
            + literal_("<-")
            + P.restOfLine.setResultsName("val")
            ).setParseAction(createBind)


parseSeq = P.restOfLine.setResultsName("val").setParseAction(createSeq)


parseDoLine = ( parseLet
              | parseBind
              | parseSeq
              )
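
If pyparsing is not available, the same three-way classification can be approximated with regular expressions from the standard library. This is only a sketch under the assumption that variable names are purely alphanumeric, as in the parsers above; classify and the layout of the returned tuple are names made up for this example.

```python
import re
from typing import Tuple

# let <var> = <val>
LET_RE = re.compile(r"^\s*let\s*([A-Za-z0-9]+)\s*=(.*)$")
# <var> <- <val>
BIND_RE = re.compile(r"^\s*([A-Za-z0-9]+)\s*<-(.*)$")


def classify(line: str) -> Tuple[str, str, str]:
    """Return (kind, var, val) for one do-line.

    The order mirrors parseDoLine: try let, then bind, then fall back
    to a plain sequence line.
    """
    m = LET_RE.match(line)
    if m:
        return ("LET", m.group(1), m.group(2).strip())
    m = BIND_RE.match(line)
    if m:
        return ("BIND", m.group(1), m.group(2).strip())
    return ("SEQ", "", line.strip())


print(classify("let x = 42"))
print(classify("res <- func(arg1, arg2)"))
print(classify("func(arg1, arg2)"))
```

As with the pyparsing version, the fallback to a sequence line has to come last, because it accepts any input.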

The implementation of the do-function

Now we can write the do-function. I again included the compose and sequence functions from the extended example. With these functions and a few simple helper functions we can translate almost every bullet point directly into a function, and compose these functions to form the overall do-function:

from functools import partial, reduce


def compose(*flist):
    def helper(*args, **kwargs):
        f0 = None
        if len(flist) == 0:
            pass
        elif len(flist) == 1:
            f0 = flist[0](*args, **kwargs)
        else:
            f0 = flist[0](*args, **kwargs)
            for fn in flist[1:]:
                f0 = fn(f0)
        return f0
    return helper

def sequence(*flist):
    if len(flist) == 0:
        return None
    else:
        f0 = flist[0]
        try:
            f0 = f0()
        except TypeError:
            # f0 is a plain value rather than a callable; use it as is
            pass
        return compose(*flist[1:])(f0)

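def isComment(s: str):
    # Despite its name, this is the keep-predicate handed to filter:
    # truthy for non-empty lines that are not comments.
    s = s.strip()
    return s and (not s.startswith("#"))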

def toDoLine(s: str):
    return parseDoLine.parseString(s)[0]

def assertSeqAtEnd(ds):
    dl = list(ds)
    if len(dl) > 0 and dl[-1].typ != DoLine.Type.SEQ:
        raise ValueError("The last line of a do must be an expression")
    return dl

def deSugar(ds):
    dl = list(ds)
    if len(dl) > 0:
        return ([d.rewrite() for d in dl[:-1]] + ['(' + dl[-1].val + ')'])
    else:
        return []

def matchParentheses(s: str):
    openedPars = s.count('(')
    closedPars = s.count(')')
    return s + (openedPars - closedPars) * ')'


def do(s: str) -> str:
    """The do-function.
    Takes a string containing the computation written in Haskell style
    do-notation and returns a string with the computation rewritten using the
    bind (>>=) and sequence (>>) functions of the monad type class.

    This function will render the whole computation in one line without
    indentation.

    The returned string should be valid to be passed to Python's `eval`.
    """
    return sequence( s
                   , lambda s: s.split('\n')
                   , partial(filter, isComment)
                   , partial(map, toDoLine)
                   , assertSeqAtEnd
                   , deSugar
                   , lambda ss: ''.join(ss)
                   , matchParentheses
                   )
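
One caveat of balancing parentheses by counting characters is that a parenthesis inside a string literal is counted as well. Since every rewritten line except the last leaves exactly one parenthesis open, a possible alternative is to close by the number of fragments instead. The sketch below illustrates the idea on hand-written fragments; closeByCount is a hypothetical helper and not part of the code above.

```python
from typing import List


def matchParentheses(s: str) -> str:
    # Same counting approach as in the article.
    return s + (s.count('(') - s.count(')')) * ')'


def closeByCount(fragments: List[str]) -> str:
    # Assumption: each fragment but the last leaves exactly one '(' open.
    return ''.join(fragments) + ')' * (len(fragments) - 1)


# The first shell command contains an unmatched '(' inside a string literal.
fragments = [
    "(shell('echo :-('))" + " >> (",
    "(shell('echo done'))",
]
counted = matchParentheses(''.join(fragments))
closed = closeByCount(fragments)
print(counted)
print(closed)
```

Here the counting version appends one closing parenthesis too many, while the fragment-based version closes exactly the one parenthesis opened by the rewrite.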

The complete code

Here is the final version of our do implementation:

from enum import IntEnum
from functools import partial, reduce
from typing import List
import unittest
import pyparsing as P


def literal(s): return P.Literal(s)
def literal_(s): return P.Suppress(P.Literal(s))
def word(s): return P.Word(s)
def word_(s): return P.Suppress(P.Word(s))


def compose(*flist):
    def helper(*args, **kwargs):
        f0 = None
        if len(flist) == 0:
            pass
        elif len(flist) == 1:
            f0 = flist[0](*args, **kwargs)
        else:
            f0 = flist[0](*args, **kwargs)
            for fn in flist[1:]:
                f0 = fn(f0)
        return f0
    return helper


def sequence(*flist):
    if len(flist) == 0:
        return None
    else:
        f0 = flist[0]
        try:
            f0 = f0()
        except TypeError:
            # f0 is a plain value rather than a callable; use it as is
            pass
        return compose(*flist[1:])(f0)


class DoLine:
    class Type(IntEnum):
        LET = 1
        SEQ = 2
        BIND = 3

    @classmethod
    def type2Str(cls, typ):
        s = ""
        if typ == cls.Type.LET:
            s = "LET"
        elif typ == cls.Type.SEQ:
            s = "SEQ"
        elif typ == cls.Type.BIND:
            s = "BIND"
        return s

    def __init__(self, typ, val: str, var:str =""):
        self.typ = typ
        self.val = val.strip()
        self.var = var.strip()

    def __str__(self):
        return (self.type2Str(self.typ) + " " + self.var + ": " + self.val)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Compare field by field against the other DoLine
        return ((self.typ == other.typ) and
                (self.val == other.val) and
                (self.var == other.var))

    def rewrite(self):
        val = f"({self.val})"
        typ = self.typ
        var = self.var
        if typ == self.Type.LET:
            return f"(stackReturn {val}) | (lambda {var}: "
        elif typ == self.Type.SEQ:
            return f"{val} >> ("
        elif typ == self.Type.BIND:
            return f"{val} | (lambda {var}: "


def createDoLine(typ, parseResult):
    context = parseResult.asDict()
    return DoLine(typ, **context)

createLet = partial(createDoLine, DoLine.Type.LET)
createSeq = partial(createDoLine, DoLine.Type.SEQ)
createBind = partial(createDoLine, DoLine.Type.BIND)


parseLet = ( literal_("let")
           + word(P.alphanums).setResultsName("var")
           + literal_("=")
           + P.restOfLine.setResultsName("val")
           ).setParseAction(createLet)


parseBind = ( word(P.alphanums).setResultsName("var")
            + literal_("<-")
            + P.restOfLine.setResultsName("val")
            ).setParseAction(createBind)


parseSeq = P.restOfLine.setResultsName("val").setParseAction(createSeq)


parseDoLine = ( parseLet
              | parseBind
              | parseSeq
              )


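def isComment(s: str):
    # Despite its name, this is the keep-predicate handed to filter:
    # truthy for non-empty lines that are not comments.
    s = s.strip()
    return s and (not s.startswith("#"))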


def toDoLine(s: str):
    return parseDoLine.parseString(s)[0]


def assertSeqAtEnd(ds):
    dl = list(ds)
    if len(dl) > 0 and dl[-1].typ != DoLine.Type.SEQ:
        raise ValueError("The last line of a do must be an expression")
    return dl


def deSugar(ds):
    dl = list(ds)
    if len(dl) > 0:
        return ([d.rewrite() for d in dl[:-1]] + ['(' + dl[-1].val + ')'])
    else:
        return []


def matchParentheses(s: str):
    openedPars = s.count('(')
    closedPars = s.count(')')
    return s + (openedPars - closedPars) * ')'


def do(s: str) -> str:
    """The do-function.
    Takes a string containing the computation written in Haskell style
    do-notation and returns a string with the computation rewritten using the
    bind (>>=) and sequence (>>) functions of the monad type class.

    This function will render the whole computation in one line without
    indentation.

    The returned string should be valid to be passed to Python's `eval`.
    """
    return sequence( s
                   , lambda s: s.split('\n')
                   , partial(filter, isComment)
                   , partial(map, toDoLine)
                   , assertSeqAtEnd
                   , deSugar
                   , lambda ss: ''.join(ss)
                   , matchParentheses
                   )


class ParserTest(unittest.TestCase):

    def test_parseLet(self):
        self.assertEqual(
                parseLet.parseString("let x = 42")[0],
                DoLine(DoLine.Type.LET, "42", "x")
                )

    def test_parseDoLine_with_let(self):
        self.assertEqual(
                parseDoLine.parseString("let x = 42")[0],
                DoLine(DoLine.Type.LET, "42", "x")
                )

    def test_parseSeq(self):
        self.assertEqual(
                parseSeq.parseString("func(arg1, arg2)")[0],
                DoLine(DoLine.Type.SEQ, "func(arg1, arg2)")
                )

    def test_parseDoLine_with_seq(self):
        self.assertEqual(
                parseDoLine.parseString("func(arg1, arg2)")[0],
                DoLine(DoLine.Type.SEQ, "func(arg1, arg2)")
                )

    def test_parseBind(self):
        self.assertEqual(
                parseBind.parseString("res <- func(arg1, arg2)")[0],
                DoLine(DoLine.Type.BIND, "func(arg1, arg2)", "res")
                )

    def test_parseDoLine_with_bind(self):
        self.assertEqual(
                parseDoLine.parseString("res <- func(arg1, arg2)")[0],
                DoLine(DoLine.Type.BIND, "func(arg1, arg2)", "res")
                )


class DoTest(unittest.TestCase):

    def test_empty(self):
        self.assertEqual(do(""), "")

    def test_oneline_seq_is_ok(self):
        test_input = 'putStrLn "Test"'
        expected = '(putStrLn "Test")'
        self.assertEqual(do(test_input), expected)

    def test_oneline_bind_fails(self):
        test_input = 'line <- getLine'
        with self.assertRaises(ValueError):
            do(test_input)

    def test_oneline_let_fails(self):
        test_input = 'let x = 42'
        with self.assertRaises(ValueError):
            do(test_input)

    def test_sequence_only(self):
        test_input = """
        putStrLn "Test"
        putStrLn "Test2"
        stackReturn ()
        """
        expected = (
                '(putStrLn "Test")' + ' >> (' +
                '(putStrLn "Test2")' + ' >> (' +
                '(stackReturn ())' +
                '))'
                )
        self.assertEqual(do(test_input), expected)

    def test_one_bind(self):
        test_input = """
        line <- getLine
        putStrLn line
        """
        expected = (
                '(getLine)' + ' | (lambda line: ' +
                '(putStrLn line)' +
                ')'
                )
        self.assertEqual(do(test_input), expected)

    def test_two_binds(self):
        test_input = """
        line <- getLine
        line2 <- getLine
        putStrLn line
        putStrLn line2
        """
        expected = (
                '(getLine)' + ' | (lambda line: ' +
                '(getLine)' + ' | (lambda line2: ' +
                '(putStrLn line)' + ' >> (' +
                '(putStrLn line2)' +
                ')))'
                )
        self.assertEqual(do(test_input), expected)

    def test_simple_let(self):
        test_input = """
        let x = 42
        putStrLn (show x)
        """
        expected = (
                '(stackReturn (42))' + ' | (lambda x: ' +
                '(putStrLn (show x))' +
                ')'
                )
        self.assertEqual(do(test_input), expected)

    def test_ignore_empty_line(self):
        test_input = """
        putStrLn "foo"

        putStrLn "bar"
        """
        expected = (
                '(putStrLn "foo")' + ' >> (' +
                '(putStrLn "bar")' +
                ')'
                )
        self.assertEqual(do(test_input), expected)

    def test_ignore_comments(self):
        test_input = """
        # first comment
        res <- complicatedComputation(arg1, arg2)
        # second comment
        shell("more stuff", res)
        """
        expected = (
                '(complicatedComputation(arg1, arg2))' + ' | (lambda res: ' +
                '(shell("more stuff", res))' +
                ')'
                )
        self.assertEqual(do(test_input), expected)

    def test_example(self):
        test_input = """
        putStr "How many gos do you want? "
        line <- getLine
        let cnt = read line :: Int
        sequence_ $ replicate cnt (putStr "go")
        putStrLn "!"
        """
        expected = (
                '(putStr "How many gos do you want? ")' + ' >> (' +
                '(getLine)' + ' | (lambda line: ' +
                '(stackReturn (read line :: Int))' + ' | (lambda cnt: ' +
                '(sequence_ $ replicate cnt (putStr "go"))' + ' >> (' +
                '(putStrLn "!")' +
                '))))'
                )
        self.assertEqual(do(test_input), expected)


if __name__ == '__main__':
    unittest.main()

Using the do-function

We are finally able to rewrite the last example of the original post using our newly won do-function. Compare it with the equivalent Haskell code that was also given there.

from functools import reduce

from transformers import *
from do_final import do


def act_on_branch(branch):
    if branch == 'master':
        return shell("echo 'action for master'")
    else:
        return shell("echo 'Other action'")

action = eval(do("""
        bOut <- shell('git branch')
        let branch = bOut.strip().split()[-1]
        act_on_branch(branch)
        shell("echo 'command independent of previous commands'")
        shell(f"echo {branch}")
        cnt <- shell('git tag | wc -l')
        let n = int(cnt)
        shell(f"/bin/zsh -c 'for i in {{1..{n}}}; do; echo 'Branch!'; done'")
        """))

res, info = runAction(action)
print(f"Final result: {res}")
print("== INFO ==")
print(reduce(lambda x, y: x + '\n' + y, info))

There it is - do in Python. That’s a big gain in readability! Note that we have to use eval to evaluate the code. We have created our own little DSL that our do-function translates (transcompiles) to Python source code that can be evaluated by the Python interpreter.

Only one small drop of bitterness is left: because our code is a string, we lose all the nice syntax highlighting.