Understanding monads

1. Introduction

This article is about monads in the context of programming, specifically about a design pattern that is often used in purely functional languages, such as Haskell. This concept originates from category theory, but this article won’t be focusing on that.

Although I haven’t personally used Haskell a lot, I have known about monads for a long time, probably since I heard about the xmonad window manager. After researching this topic for a few days, I decided to write a small article about what I learned, since it might be useful for some people. This article won’t focus specifically on Haskell, so most concepts can be applied to many languages, and not a lot of programming background is required from the reader.

Lastly, if you are interested in monads specifically in Haskell, I recommend you check out the articles at this link.

2. Pure functions

Before trying to understand monads and their use cases, it’s important to know about pure functions, and why they are desirable. In order for a function to be pure, it must satisfy 2 properties:

  1. The value returned by the function depends only on its inputs, so identical values are returned across multiple calls with identical inputs.
  2. The function must not have side effects, such as modifying global variables, input/output streams, etc.

For example, the following Python functions are pure[1], since they don’t break any of the rules above.

def double_sum(a, b):
    return (a + b) * 2

def count_uppercase(string):
    result = 0
    for character in string:
        if character.isupper():
            result += 1
    return result

However, the following are not pure, because they are either influenced by the external program state, or because they alter it.

my_global = 0

# Accesses external state.
def global_sum(n):
    return my_global + n

# Modifies external state (global variable).
def impure_sum(a, b):
    global my_global  # Without this, the assignment would create a local variable.
    my_global = 123
    return a + b

# Modifies external state (I/O stream).
def print_sum(a, b):
    result = a + b
    print(f"Result: {result}")
    return result
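
To make the first case more concrete, the short sketch below calls global_sum twice with the same argument while the global variable changes in between. The names first and second are only illustrative and are not used elsewhere in this article.

```python
my_global = 0

# Accesses external state, so its result depends on more than its argument.
def global_sum(n):
    return my_global + n

first = global_sum(5)   # my_global is 0 at this point
my_global = 10          # Some other part of the program changes the global
second = global_sum(5)  # Same argument, different result

print(first, second)  # 5 15
```

Since identical inputs produced different outputs, the first property of pure functions is broken.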

Pure functions offer many advantages over impure functions, for both the machine (compiler, interpreter) and the human (designing the architecture, programming, testing, documenting). Some of these advantages include:

  • Result caching (memoization): Since a pure function is guaranteed to return the same result for the same inputs, its results can be stored in a look-up table and reused when the function is called again with the same arguments, improving performance.
  • Better optimizations: If the compiler/interpreter knows that a function is pure, it can safely perform many more optimizations, including result caching, or skipping calls altogether[2].
  • Predictability and security: From the programmer’s point of view, it’s much easier to test a pure function than an impure one, since it doesn’t depend on an external environment that might make the tests more unreliable or complex.
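
As a rough illustration of result caching, the sketch below wraps the count_uppercase function from the previous section with functools.lru_cache from the Python standard library. This is only one possible way of memoizing a function; the input string and the variable names are arbitrary.

```python
from functools import lru_cache

# Caching is only safe because the function is pure: the same string always
# produces the same count.
@lru_cache(maxsize=None)
def count_uppercase(string):
    result = 0
    for character in string:
        if character.isupper():
            result += 1
    return result

first = count_uppercase("Hello World")   # Computed (cache miss)
second = count_uppercase("Hello World")  # Looked up (cache hit)

info = count_uppercase.cache_info()
print(first, second, info.hits, info.misses)  # 2 2 1 1
```

Caching an impure function this way could return stale results, which is why this optimization relies on purity.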

3. From impure to monadic

Before diving into what monads are, it’s important to understand the problem that they attempt to solve. Monads are useful when a series of steps or computations (e.g. function calls) need to be combined, and when these computations somehow extend an input type, adding some kind of context. For example, the programmer might want to combine or link a series of functions whose arguments are integers, but whose returned data type is a more complex structure that extends the received integer type.

3.1. A simple writer

To illustrate this problem, along with a possible solution, the writer monad will be used, which enables functions that operate on a base type (e.g. on integers) to also write a series of logs into a common list.

Below is some Python code that implements this writer functionality without using a monad. Each function performs some computation (in this case, simple arithmetic operations), while also logging a string into a global list variable.

global_log_list = []

def add(a, b):
    global_log_list.append(f"Added {a} to {b}")
    return a + b

def sub(a, b):
    global_log_list.append(f"Subtracted {b} from {a}")
    return a - b

def mul(a, b):
    global_log_list.append(f"Multiplied {a} by {b}")
    return a * b

def div(a, b):
    global_log_list.append(f"Divided {a} by {b}")
    return a / b

To combine calls to these functions with the current design, one may store the result of each call in a variable, pass it to the next function as an argument, and overwrite the stored result with the new returned value. After the chain of computations is done, the variable holds the final result, and the global list holds the logs that were generated by the functions.

result = add(6, 5)
result = sub(result, 4)
result = mul(result, 3)
result = div(result, 2)

print(f"Final result: {result}")
print("Logs:")
for line in global_log_list:
    print(f"  * {line}")

The previous code would print the following output:

Final result: 10.5
Logs:
  * Added 6 to 5
  * Subtracted 4 from 11
  * Multiplied 7 by 3
  * Divided 21 by 2

Notice how the internal structure of those functions is very similar, and since much behavior is shared, they could be further abstracted. Also note how the functions are not pure because they produce side effects by modifying a global variable.

3.2. Making the writer pure

Instead of modifying a global list, these functions could return their log line as part of the result, making them pure. In order to chain multiple functions, combining their logs, they could also receive the previous log list as part of their arguments.

First, a new data type should be defined, which extends the integer type by adding the log list.

class LoggedInt:
    def __init__(self, val, logs):
        self.val = val
        self.logs = logs

# Example
logged_int = LoggedInt(5, ["Some log line", "Another log line"])

The arithmetic functions can be modified to receive and return this new data type, appending the new log line to the previous log list. Note how the first argument of the following functions is a LoggedInt, but the second argument is still a simple integer.

def add(logged_a, b):
    return LoggedInt(
        logged_a.val + b,  # New value
        logged_a.logs + [f"Added {logged_a.val} to {b}"]  # Extended log list
    )

def sub(logged_a, b):
    return LoggedInt(
        logged_a.val - b,
        logged_a.logs + [f"Subtracted {b} from {logged_a.val}"]
    )

def mul(logged_a, b):
    return LoggedInt(
        logged_a.val * b,
        logged_a.logs + [f"Multiplied {logged_a.val} by {b}"]
    )

def div(logged_a, b):
    return LoggedInt(
        logged_a.val / b,
        logged_a.logs + [f"Divided {logged_a.val} by {b}"]
    )

The usage of these functions is similar to the previous ones, but since they now receive a LoggedInt as their first argument, the first input integer needs to be promoted to a LoggedInt, initially with an empty log list.

logged_result = LoggedInt(6, [])
logged_result = add(logged_result, 5)
logged_result = sub(logged_result, 4)
logged_result = mul(logged_result, 3)
logged_result = div(logged_result, 2)

print(f"Final result: {logged_result.val}")
print("Logs:")
for line in logged_result.logs:
    print(f"  * {line}")

With this simple change, the functions are now pure. At this point, however, this design pattern isn’t exactly a monad, and some of the shared logic can be extracted into separate functions.
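
The purity claim can be checked with a small sketch that reuses the LoggedInt class and the add function from this section. The names start, first and second are only illustrative. Calling add twice with the same inputs should return the same value and logs, and the input should be left untouched.

```python
class LoggedInt:
    def __init__(self, val, logs):
        self.val = val
        self.logs = logs

def add(logged_a, b):
    return LoggedInt(
        logged_a.val + b,
        # List concatenation builds a new list instead of mutating the old one.
        logged_a.logs + [f"Added {logged_a.val} to {b}"]
    )

start = LoggedInt(6, [])
first = add(start, 5)
second = add(start, 5)

print(start.val, start.logs)    # 6 []
print(first.val, first.logs)    # 11 ['Added 6 to 5']
print(second.val, second.logs)  # 11 ['Added 6 to 5']
```

Note that this relies on building a new list with the + operator; calling append on the received list would modify the caller’s data and make the function impure again.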

3.3. Extracting the binding logic

The previous code can be further abstracted by moving the “combination logic” into a separate binding function. Before defining this bind function, the arithmetic functions should be modified so they return a LoggedInt while still receiving unwrapped integers.

def add(a, b):
    return LoggedInt(
        a + b,  # New value
        [f"Added {a} to {b}"]  # Written log line
    )

def sub(a, b):
    return LoggedInt(
        a - b,
        [f"Subtracted {b} from {a}"]
    )

def mul(a, b):
    return LoggedInt(
        a * b,
        [f"Multiplied {a} by {b}"]
    )

def div(a, b):
    return LoggedInt(
        a / b,
        [f"Divided {a} by {b}"]
    )

The functions now receive two simple integers, and return a new LoggedInt that contains the result value and the log line written by that specific function. Note how the log line needs to be wrapped in a one-element list, since the LoggedInt type expects a log list, not a string.

Now that the combination logic has been removed from the arithmetic functions, the bind function can be implemented, which receives a LoggedInt value, one of the arithmetic functions and a simple integer. It performs the following steps:

  1. Unwrap/extract the original integer value from a, the received LoggedInt.
  2. Call the arithmetic function (received as an argument) with the unwrapped value and b, the received simple integer.
  3. Combine the logs of the received LoggedInt with the logs of the LoggedInt that was returned by the arithmetic function.

Through this process, it applies the received function to the other two values, and combines that result with the original LoggedInt value.

def bind(old_logged_int, function, b):
    unwrapped_val = old_logged_int.val
    new_logged_int = function(unwrapped_val, b)
    return LoggedInt(
        new_logged_int.val,
        old_logged_int.logs + new_logged_int.logs
    )

Instead of calling the arithmetic functions directly, they are now passed as arguments to bind, which will call the function and combine the logs, returning a new LoggedInt result.

logged_result = LoggedInt(6, [])
logged_result = bind(logged_result, add, 5)
logged_result = bind(logged_result, sub, 4)
logged_result = bind(logged_result, mul, 3)
logged_result = bind(logged_result, div, 2)

print(f"Final result: {logged_result.val}")
print("Logs:")
for line in logged_result.logs:
    print(f"  * {line}")

Furthermore, the first input doesn’t need to be promoted into a LoggedInt explicitly anymore, since the arithmetic functions now receive simple integers.

logged_result = add(6, 5)  # No explicit call to 'LoggedInt'
logged_result = bind(logged_result, sub, 4)
logged_result = bind(logged_result, mul, 3)
# ...

3.4. Making the writer a monad

In order to turn the writer code into a monad, there is one last change that needs to be made. The current bind function receives 3 arguments, the last one being a simple integer because it’s what the arithmetic functions expect. The bind function of a proper monad should only receive 2 arguments: a value, whose type is monadic (e.g. LoggedInt), and a monadic function, which receives a simple value (e.g. an integer) and returns a new monadic value.

def bind(old_logged_int, function): # Receives two arguments
    unwrapped_val = old_logged_int.val
    new_logged_int = function(unwrapped_val)  # Called with one argument
    return LoggedInt(
        new_logged_int.val,
        old_logged_int.logs + new_logged_int.logs
    )

After this change, how could the new bind function receive the arithmetic functions, if they receive two arguments, a and b? This problem has an easy solution, although it’s not particularly pretty depending on the programming language. Any multi-argument function can be converted into a chain of one-argument functions by returning a lambda, a technique known as currying. For example, the following two function calls are equivalent.

# Define a function that receives integers 'a', 'b' and 'c', and returns an
# integer with the result.
def foo(a, b, c):
    return a + b * c

# Define a function that receives an integer 'a', and returns an anonymous
# function that receives an integer 'b', and returns an anonymous function that
# receives an integer 'c' and returns an integer with the result.
def bar(a):
    return lambda b: lambda c: a + b * c

# Example calls.
foo(5, 6, 7)
bar(5)(6)(7)

Therefore, the arithmetic functions themselves don’t need to be modified, since the following expressions would be equivalent:

add(5, 6)

# Equivalent one argument function.
add_six = lambda a: add(a, 6)
add_six(5)

This is how the bind functions would be called to match the previous example.

logged_result = add(6, 5)
logged_result = bind(logged_result, lambda a: sub(a, 4))
logged_result = bind(logged_result, lambda a: mul(a, 3))
logged_result = bind(logged_result, lambda a: div(a, 2))
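
As an alternative to writing the lambdas by hand, Python’s standard functools.partial can fix the second argument by keyword. The following is just a sketch under the assumption that the arithmetic functions name their second parameter b, as they do in this article; the definitions are repeated so the snippet runs on its own.

```python
from functools import partial

class LoggedInt:
    def __init__(self, val, logs):
        self.val = val
        self.logs = logs

def bind(old_logged_int, function):
    new_logged_int = function(old_logged_int.val)
    return LoggedInt(
        new_logged_int.val,
        old_logged_int.logs + new_logged_int.logs
    )

def add(a, b):
    return LoggedInt(a + b, [f"Added {a} to {b}"])

def mul(a, b):
    return LoggedInt(a * b, [f"Multiplied {a} by {b}"])

# partial(mul, b=3) behaves like: lambda a: mul(a, 3)
logged_result = bind(add(6, 5), partial(mul, b=3))

print(logged_result.val)   # 33
print(logged_result.logs)  # ['Added 6 to 5', 'Multiplied 11 by 3']
```

Whether this reads better than a lambda is a matter of taste; both produce a one-argument function that bind can call.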

Furthermore, using an object-oriented approach, the bind function can be converted to a method of LoggedInt, allowing the caller to bind functions with a cleaner notation, since it now accesses the instance of the object.

class LoggedInt:
    def __init__(self, val, logs):  # Unchanged
        self.val = val
        self.logs = logs

    def bind(self, function):
        new_logged_int = function(self.val)
        return LoggedInt(
            new_logged_int.val,
            self.logs + new_logged_int.logs
        )

# Example usage.
logged_result = (
    add(6, 5).bind(lambda a: sub(a, 4))
             .bind(lambda a: mul(a, 3))
             .bind(lambda a: div(a, 2))
)

3.5. Final code

Some small modifications can be made, like adding a default value to the second argument of the constructor. There is also a new one-argument monadic function for squaring a number.

This is the final Python code for the writer monad.

# Monadic type, expands a base integer type to add logging functionality.
class LoggedInt:
    def __init__(self, val, logs=None):
        self.val = val
        # Avoid a mutable default argument, which would be shared across calls.
        self.logs = [] if logs is None else logs

    # Applies a one-argument monadic function to the current instance, and
    # combines the result with the existing log list.
    def bind(self, function):
        new_logged_int = function(self.val)
        return LoggedInt(
            new_logged_int.val,
            self.logs + new_logged_int.logs
        )

# Monadic functions.
def add(a, b):
    return LoggedInt(a + b, [f"Added {a} to {b}"])
def sub(a, b):
    return LoggedInt(a - b, [f"Subtracted {b} from {a}"])
def mul(a, b):
    return LoggedInt(a * b, [f"Multiplied {a} by {b}"])
def div(a, b):
    return LoggedInt(a / b, [f"Divided {a} by {b}"])
def square(a):
    return LoggedInt(a * a, [f"Squared {a}"])

# Example usage.
logged_result = (
    add(6, 5).bind(lambda a: sub(a, 4))
             .bind(lambda a: mul(a, 3))
             .bind(lambda a: div(a, 2))
             .bind(lambda a: square(a))
)

print(f"Final result: {logged_result.val}")
print("Logs:")
for line in logged_result.logs:
    print(f"  * {line}")

4. What are monads?

A monad is a design pattern that is often used in functional programming for encapsulating and combining a series of steps or computations (i.e. function calls) within a context. It allows them to produce side effects in that controlled environment while keeping the functions pure.

In the previous example, the LoggedInt data type implements the writer monad by extending a base integer type, where the logs are the side effects that should be encapsulated within the monadic context, and the bind function provides the means for combining functions in this environment.

A “monad” is not a specific data type, like Integer or String; it is a design pattern that certain types might implement. It is similar to an interface in object-oriented programming.
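
To illustrate that the same pattern can be implemented by a completely different type, here is a minimal sketch of a hypothetical Maybe type, whose context is “the computation may have failed” instead of a log list. The names Maybe, is_nothing and safe_div are assumptions made for this example, and are not part of the writer code above.

```python
# Hypothetical monadic type: wraps a value, or marks the absence of one.
class Maybe:
    def __init__(self, val, is_nothing=False):
        self.val = val
        self.is_nothing = is_nothing

    # If a previous step failed, skip the function and propagate the failure.
    # Otherwise, apply the monadic function to the unwrapped value.
    def bind(self, function):
        if self.is_nothing:
            return self
        return function(self.val)

# Monadic function: receives simple values, returns a Maybe.
def safe_div(a, b):
    if b == 0:
        return Maybe(None, is_nothing=True)
    return Maybe(a / b)

ok = Maybe(20).bind(lambda a: safe_div(a, 4)).bind(lambda a: safe_div(a, 2))
failed = Maybe(20).bind(lambda a: safe_div(a, 0)).bind(lambda a: safe_div(a, 2))

print(ok.val, ok.is_nothing)          # 2.5 False
print(failed.val, failed.is_nothing)  # None True
```

The calling code has the same shape as in the writer example: a constructor wraps the initial value, and bind chains one-argument monadic functions while handling the context internally.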

4.1. Properties of monads

A monad must implement a monadic type (LoggedInt), and two operations:

  1. A constructor or wrapper[3], which receives the base type (e.g. an integer) and wraps it into the monadic type. In the previous example, this was the __init__ constructor of LoggedInt.
  2. A binding function[4], which receives a monadic value and a monadic function. It applies the monadic function to the unwrapped value, and preserves the old monadic context by somehow joining or combining it with the one returned by the monadic function.

Note how the monad doesn’t need to provide any “unwrapping” functionality that allows the caller to extract a simple value from a monadic one; that logic is part of the bind function, which unwraps the monadic value it receives before applying the monadic function to it.

Furthermore, monads must satisfy 3 laws that determine the behavior of the constructor and the binding function:

Left identity

Promoting a simple value into the monadic type through the constructor, and then binding it to a monadic function is equivalent to applying the function to the simple value directly.

The previous writer example satisfies this law, because the following expressions are equivalent:

# Call the monadic function with simple value.
result1 = square(6)

# Promote simple value to monadic type, and bind it to monadic function.
result2 = LoggedInt(6).bind(square)

Right identity

Binding a monadic value to the constructor doesn’t affect the monadic value. This ensures the constructor doesn’t alter the monadic environment.

In the previous writer example, this guarantees that the monadic constructor doesn’t initialize the log list (i.e. the monadic context) with any information. This can be checked because the following expressions are equivalent:

# Promote a simple value to a monadic type, through the constructor.
result1 = LoggedInt(6)

# Promote a simple value to a monadic type, and bind that to the constructor.
result2 = LoggedInt(6).bind(LoggedInt)

Associativity

The way in which bind calls are grouped or nested does not affect the result. Note that the order in which monadic functions are applied to monadic values does matter; this property only states that chaining two binds one after another is equivalent to nesting the second bind inside the function that is passed to the first one.

In the previous writer example:

# Define another simple one-argument monadic function.
def add_five(a):
    return add(a, 5)

# Promote a simple value to a monadic type, bind that to 'square', and then
# bind the result to 'add_five'.
result1 = LoggedInt(6).bind(square).bind(add_five)

# Promote a simple value to a monadic type, and bind that to a lambda which
# calls 'square' on its input and binds that to 'add_five'.
result2 = LoggedInt(6).bind(lambda a: square(a).bind(add_five))
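
The three laws can also be checked programmatically. The sketch below repeats the relevant definitions so it can run on its own, and introduces a hypothetical helper called equivalent, which treats two LoggedInt values as equivalent when both their values and their log lists match. The right identity check uses a value that already carries a log line, which is a slightly stronger test than the one shown above.

```python
class LoggedInt:
    def __init__(self, val, logs=None):
        self.val = val
        self.logs = [] if logs is None else logs

    def bind(self, function):
        new_logged_int = function(self.val)
        return LoggedInt(
            new_logged_int.val,
            self.logs + new_logged_int.logs
        )

def add(a, b):
    return LoggedInt(a + b, [f"Added {a} to {b}"])

def square(a):
    return LoggedInt(a * a, [f"Squared {a}"])

def add_five(a):
    return add(a, 5)

# Hypothetical helper, not part of the monad itself.
def equivalent(x, y):
    return x.val == y.val and x.logs == y.logs

left_identity = equivalent(LoggedInt(6).bind(square), square(6))
right_identity = equivalent(square(6).bind(LoggedInt), square(6))
associativity = equivalent(
    LoggedInt(6).bind(square).bind(add_five),
    LoggedInt(6).bind(lambda a: square(a).bind(add_five))
)

print(left_identity, right_identity, associativity)  # True True True
```

Passing checks for a few sample values don’t prove the laws in general, but they are a cheap way of catching an implementation that breaks them.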

Footnotes:

[1] In lower-level programming languages, like C, one might argue that string-manipulation functions are not actually pure, since they often receive a pointer whose value might change across calls. I still decided to categorize the count_uppercase function as pure, in higher-level programming languages, since it produces identical results when given the same string inputs.

[2] For example, if the length of the same string is calculated multiple times, and the string doesn’t change, the compiler/interpreter could perform a single call and reuse that value.

[3] This function is called return or pure in Haskell, but those names can be confusing, especially when used outside a do block. Note that pure is part of the Applicative typeclass, which is a superclass of Monad.

[4] This function is usually called bind, and in Haskell it is available as the >>= operator.