Understanding monads
1. Introduction
This article is about monads in the context of programming, specifically about a design pattern that is often used in purely functional languages, such as Haskell. This concept originates from category theory, but this article won’t be focusing on that.
Although I haven’t used Haskell much, I have known about monads for a long time, probably since I first heard about the xmonad window manager. After researching this topic for a few days, I decided to write a short article about what I learned, since it might be useful to others. Most of the concepts covered here can be applied to many languages, and not much programming background is required from the reader.
Lastly, this article won’t focus specifically on Haskell; if you are interested in that, I recommend checking out the articles at this link.
2. Pure functions
Before trying to understand monads and their use cases, it’s important to know what pure functions are and why they are desirable. For a function to be pure, it must satisfy two properties:
- The value returned by the function depends only on its inputs, so identical values are returned across multiple calls with identical inputs.
- The function must not have side effects, such as modifying global variables or reading from and writing to input/output streams.
For example, the following Python functions are pure[1], since they don’t break any of the rules above.
def double_sum(a, b):
    return (a + b) * 2

def count_uppercase(string):
    result = 0
    for character in string:
        if character.isupper():
            result += 1
    return result
However, the following are not pure, because they are either influenced by the external program state, or because they alter it.
my_global = 0

# Accesses external state.
def global_sum(n):
    return my_global + n

# Modifies external state (global variable).
def impure_sum(a, b):
    global my_global  # Without this line, the assignment would only create a local variable.
    my_global = 123
    return a + b

# Modifies external state (I/O stream).
def print_sum(a, b):
    result = a + b
    print(f"Result: {result}")
    return result
Pure functions offer many advantages over impure functions, for both the machine (compiler, interpreter) and the human (designing the architecture, programming, testing, documenting). Some of these advantages include:
- Result caching (memoization): Since a pure function is guaranteed to return the same result for the same inputs, its results can be stored in a look-up table and reused when it is called repeatedly with the same arguments, improving performance.
- Better optimizations: If the compiler/interpreter knows that a function is pure, it can safely perform many more optimizations, including result caching or skipping calls altogether[2].
- Predictability and security: From the programmer’s point of view, a pure function is much easier to test than an impure one, since it doesn’t depend on an external environment that might make the tests unreliable or complex.
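Result caching does not have to be written by hand. As a minimal sketch (not part of the original examples), Python’s standard functools.lru_cache decorator can memoize the double_sum function from above; this is only safe because double_sum is pure, and cache_info() reports how many calls were answered from the cache.

```python
from functools import lru_cache

# Caching is only safe here because double_sum is pure: the same inputs
# always produce the same result, and a call has no side effects to skip.
@lru_cache(maxsize=None)
def double_sum(a, b):
    return (a + b) * 2

first = double_sum(2, 3)   # Cache miss: the body runs and the result is stored.
second = double_sum(2, 3)  # Cache hit: the stored result is returned.

info = double_sum.cache_info()
print(first, second, info.hits, info.misses)  # 10 10 1 1
```

Applying the same decorator to print_sum would be a mistake: the cached call would silently skip the print on repeated calls.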
3. From impure to monadic
Before diving into what monads are, it’s important to understand the problem that they attempt to solve. Monads are useful when a series of steps or computations (e.g. function calls) needs to be combined, and these computations extend an input type by adding some kind of context. For example, the programmer might want to chain a series of functions whose arguments are integers, but whose return type is a more complex structure that extends the integer type.
3.1. A simple writer
To illustrate this problem, along with a possible solution, this article uses the writer monad, which lets functions that operate on a base type (e.g. integers) also write a series of logs into a common list.
Below is some Python code that implements this writer functionality without using a monad. Each function performs some computation (in this case, simple arithmetic operations), while also logging a string into a global list variable.
global_log_list = []

def add(a, b):
    global_log_list.append(f"Added {a} to {b}")
    return a + b

def sub(a, b):
    global_log_list.append(f"Subtracted {b} from {a}")
    return a - b

def mul(a, b):
    global_log_list.append(f"Multiplied {a} by {b}")
    return a * b

def div(a, b):
    global_log_list.append(f"Divided {a} by {b}")
    return a / b
To combine calls to these functions with the current design, one may store the result of each call in a variable, pass it to the next function as an argument, and overwrite the stored result with the new returned value. After the chain of computations is done, the final result is printed, along with the list of logs that were generated by the functions.
result = add(6, 5)
result = sub(result, 4)
result = mul(result, 3)
result = div(result, 2)

print(f"Final result: {result}")
print("Logs:")
for line in global_log_list:
    print(f" * {line}")
The previous code would print the following output:
Final result: 10.5
Logs:
 * Added 6 to 5
 * Subtracted 4 from 11
 * Multiplied 7 by 3
 * Divided 21 by 2
Notice how the internal structure of those functions is very similar, and since much behavior is shared, they could be further abstracted. Also note how the functions are not pure because they produce side effects by modifying a global variable.
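A quick way to see why the global list is a problem: running the same computation twice leaves the logs of both runs in global_log_list, so what the program observes afterwards depends on everything that ran before. The sketch below repeats only the add function from the code above, and assumes nothing else writes to the list.

```python
global_log_list = []

def add(a, b):
    global_log_list.append(f"Added {a} to {b}")
    return a + b

# Run the same computation twice.
first_result = add(6, 5)
second_result = add(6, 5)

# The return values match, but the shared log list now holds entries from
# both calls: each call changed the program state.
print(first_result, second_result)  # 11 11
print(global_log_list)  # ['Added 6 to 5', 'Added 6 to 5']
```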
3.2. Making the writer pure
Instead of modifying a global list, these functions could return their log line as part of the result, making them pure. In order to chain multiple functions, combining their logs, they could also receive the previous log list as part of their arguments.
First, a new data type should be defined, which extends the integer type by adding the log list.
class LoggedInt:
    def __init__(self, val, logs):
        self.val = val
        self.logs = logs

# Example
logged_int = LoggedInt(5, ["Some log line", "Another log line"])
The arithmetic functions can be modified to receive and return this new data type, appending the new log line to the previous log list. Note how the first argument of the following functions is a LoggedInt, but the second argument is still a simple integer.
def add(logged_a, b):
    return LoggedInt(
        logged_a.val + b,  # New value
        logged_a.logs + [f"Added {logged_a.val} to {b}"]  # Extended log list
    )

def sub(logged_a, b):
    return LoggedInt(
        logged_a.val - b,
        logged_a.logs + [f"Subtracted {b} from {logged_a.val}"]
    )

def mul(logged_a, b):
    return LoggedInt(
        logged_a.val * b,
        logged_a.logs + [f"Multiplied {logged_a.val} by {b}"]
    )

def div(logged_a, b):
    return LoggedInt(
        logged_a.val / b,
        logged_a.logs + [f"Divided {logged_a.val} by {b}"]
    )
The usage of these functions is similar to the previous ones, but since they now receive a LoggedInt as their first argument, the first input integer needs to be promoted to a LoggedInt, initially with an empty log list.
logged_result = LoggedInt(6, [])
logged_result = add(logged_result, 5)
logged_result = sub(logged_result, 4)
logged_result = mul(logged_result, 3)
logged_result = div(logged_result, 2)

print(f"Final result: {logged_result.val}")
print("Logs:")
for line in logged_result.logs:
    print(f" * {line}")
With this simple change, the functions are now pure. At this point, however, this design pattern isn’t exactly a monad, and some of the shared logic can be extracted into separate functions.
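As a quick check of the purity claim, the sketch below (which repeats LoggedInt and the LoggedInt-based add so that it runs on its own) calls add twice with the same input: both calls return equal values and logs, and the input’s log list is left untouched, because each function builds a new list with + instead of appending in place.

```python
class LoggedInt:
    def __init__(self, val, logs):
        self.val = val
        self.logs = logs

def add(logged_a, b):
    return LoggedInt(
        logged_a.val + b,
        logged_a.logs + [f"Added {logged_a.val} to {b}"]  # Builds a new list.
    )

start = LoggedInt(6, [])

# Same input, same output, and no shared state is modified.
result1 = add(start, 5)
result2 = add(start, 5)

print(result1.val, result1.logs)  # 11 ['Added 6 to 5']
print(result1.val == result2.val and result1.logs == result2.logs)  # True
print(start.logs)  # []
```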
3.3. Extracting the binding logic
The previous code can be further abstracted by moving the “combination logic” into a separate binding function. Before defining this bind function, the arithmetic functions should be modified so they return a LoggedInt while still receiving unwrapped integers.
def add(a, b):
    return LoggedInt(
        a + b,  # New value
        [f"Added {a} to {b}"]  # Written log line
    )

def sub(a, b):
    return LoggedInt(
        a - b,
        [f"Subtracted {b} from {a}"]
    )

def mul(a, b):
    return LoggedInt(
        a * b,
        [f"Multiplied {a} by {b}"]
    )

def div(a, b):
    return LoggedInt(
        a / b,
        [f"Divided {a} by {b}"]
    )
The functions now receive two simple integers, and return a new LoggedInt that contains the result value and the log line written by that specific function. Note how the log line needs to be wrapped in a one-element list, since the LoggedInt type expects a log list, not a string.
Now that the combination logic has been removed from the arithmetic functions, the bind function can be implemented. It receives a LoggedInt value, one of the arithmetic functions and a simple integer, and performs the following steps:
- Unwrap/extract the original integer value from the received LoggedInt (old_logged_int in the code below).
- Call the arithmetic function (received as an argument) with the unwrapped value and b, the received simple integer.
- Combine the logs of the received LoggedInt with the logs of the LoggedInt returned by the arithmetic function.
Through this process, bind applies the received function to the unwrapped value and b, and combines the result with the original LoggedInt value.
def bind(old_logged_int, function, b):
    unwrapped_val = old_logged_int.val
    new_logged_int = function(unwrapped_val, b)
    return LoggedInt(
        new_logged_int.val,
        old_logged_int.logs + new_logged_int.logs
    )
Instead of being called directly, the arithmetic functions are now passed as arguments to bind, which calls the function and combines the logs, returning a new LoggedInt result.
logged_result = LoggedInt(6, [])
logged_result = bind(logged_result, add, 5)
logged_result = bind(logged_result, sub, 4)
logged_result = bind(logged_result, mul, 3)
logged_result = bind(logged_result, div, 2)

print(f"Final result: {logged_result.val}")
print("Logs:")
for line in logged_result.logs:
    print(f" * {line}")
Furthermore, the first input doesn’t need to be promoted into a LoggedInt explicitly anymore, since the arithmetic functions now receive simple integers.
logged_result = add(6, 5)  # No explicit call to 'LoggedInt'
logged_result = bind(logged_result, sub, 4)
logged_result = bind(logged_result, mul, 3)
# ...
3.4. Making the writer a monad
In order to turn the writer code into a monad, there is one last change that needs to be made. The current bind function receives 3 arguments, the last one being a simple integer because it’s what the arithmetic functions expect. The bind function of a proper monad should only receive 2 arguments: a value, whose type is monadic (e.g. LoggedInt), and a monadic function, which receives a simple value (e.g. an integer) and returns a new monadic value.
def bind(old_logged_int, function):  # Receives two arguments
    unwrapped_val = old_logged_int.val
    new_logged_int = function(unwrapped_val)  # Called with one argument
    return LoggedInt(
        new_logged_int.val,
        old_logged_int.logs + new_logged_int.logs
    )
After this change, how could the new bind function receive the arithmetic functions, if they receive two arguments, a and b? This problem has an easy solution, although it’s not particularly pretty, depending on the programming language. All functions can be converted into one-argument functions by returning a lambda. For example, the following two function calls are equivalent.
# Define a function that receives integers 'a', 'b' and 'c', and returns an
# integer with the result.
def foo(a, b, c):
    return a + b * c

# Define a function that receives an integer 'a', and returns an anonymous
# function that receives an integer 'b', and returns an anonymous function that
# receives an integer 'c' and returns an integer with the result.
def bar(a):
    return lambda b: lambda c: a + b * c

# Example calls.
foo(5, 6, 7)
bar(5)(6)(7)
Therefore, the arithmetic functions themselves don’t need to be modified: wrapping a call in a lambda that fixes the second argument produces a one-argument function, so the following expressions are equivalent:
add(5, 6)

# Equivalent one-argument function.
add_six = lambda a: add(a, 6)
add_six(5)
This is how the bind function would be called to match the previous example.
logged_result = add(6, 5)
logged_result = bind(logged_result, lambda a: sub(a, 4))
logged_result = bind(logged_result, lambda a: mul(a, 3))
logged_result = bind(logged_result, lambda a: div(a, 2))
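In Python specifically, the standard functools.partial offers an alternative to writing a lambda for each step. Because the arithmetic functions name their parameters a and b, partial(sub, b=4) produces a one-argument function that behaves like lambda a: sub(a, 4). The sketch below swaps partial in for the lambdas; it repeats LoggedInt, the arithmetic functions and the two-argument bind so that it runs on its own.

```python
from functools import partial

class LoggedInt:
    def __init__(self, val, logs):
        self.val = val
        self.logs = logs

def add(a, b):
    return LoggedInt(a + b, [f"Added {a} to {b}"])

def sub(a, b):
    return LoggedInt(a - b, [f"Subtracted {b} from {a}"])

def mul(a, b):
    return LoggedInt(a * b, [f"Multiplied {a} by {b}"])

def div(a, b):
    return LoggedInt(a / b, [f"Divided {a} by {b}"])

def bind(old_logged_int, function):
    new_logged_int = function(old_logged_int.val)
    return LoggedInt(new_logged_int.val, old_logged_int.logs + new_logged_int.logs)

# partial(sub, b=4) fixes the keyword argument 'b'; bind supplies 'a'.
logged_result = add(6, 5)
logged_result = bind(logged_result, partial(sub, b=4))
logged_result = bind(logged_result, partial(mul, b=3))
logged_result = bind(logged_result, partial(div, b=2))

print(logged_result.val)  # 10.5
print(logged_result.logs)
```

This relies on the parameter being named b; for functions whose parameters cannot be passed by keyword, the lambda form is the more general choice.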
Furthermore, using an object-oriented approach, the bind function can be converted into a method of LoggedInt, allowing the caller to bind functions with a cleaner, chained notation, since the method has direct access to the object instance.
class LoggedInt:
    def __init__(self, val, logs):  # Unchanged
        self.val = val
        self.logs = logs

    def bind(self, function):
        new_logged_int = function(self.val)
        return LoggedInt(
            new_logged_int.val,
            self.logs + new_logged_int.logs
        )

# Example usage.
logged_result = (
    add(6, 5).bind(lambda a: sub(a, 4))
    .bind(lambda a: mul(a, 3))
    .bind(lambda a: div(a, 2))
)
3.5. Final code
Some small modifications can be made, like giving the second argument of the constructor a default value, so that the log list starts empty when no logs are passed. There is also a new one-argument monadic function for squaring a number.
This is the final Python code for the writer monad.
# Monadic type, expands a base integer type to add logging functionality.
class LoggedInt:
    def __init__(self, val, logs=None):
        self.val = val
        # Use None instead of a mutable default ([]), which Python would share
        # between every instance created without explicit logs.
        self.logs = logs if logs is not None else []

    # Applies a one-argument monadic function to the current instance, and
    # combines the result with the existing log list.
    def bind(self, function):
        new_logged_int = function(self.val)
        return LoggedInt(
            new_logged_int.val,
            self.logs + new_logged_int.logs
        )

# Monadic functions.
def add(a, b):
    return LoggedInt(a + b, [f"Added {a} to {b}"])

def sub(a, b):
    return LoggedInt(a - b, [f"Subtracted {b} from {a}"])

def mul(a, b):
    return LoggedInt(a * b, [f"Multiplied {a} by {b}"])

def div(a, b):
    return LoggedInt(a / b, [f"Divided {a} by {b}"])

def square(a):
    return LoggedInt(a * a, [f"Squared {a}"])

# Example usage.
logged_result = (
    add(6, 5).bind(lambda a: sub(a, 4))
    .bind(lambda a: mul(a, 3))
    .bind(lambda a: div(a, 2))
    .bind(lambda a: square(a))
)

print(f"Final result: {logged_result.val}")
print("Logs:")
for line in logged_result.logs:
    print(f" * {line}")
4. What are monads?
A monad is a design pattern that is often used in functional programming for encapsulating and combining a series of steps or computations (i.e. function calls) within a context, allowing them to produce side effects within that controlled environment while keeping the functions pure.
In the previous example, the LoggedInt data type implements the writer monad by extending a base integer type, where the logs are the side effects that should be encapsulated within the monadic context, and the bind function provides the means for combining functions in this environment.
A “monad” is not a specific data type, like Integer or String; it is a design pattern that certain types might implement. It is similar to an interface in object-oriented programming.
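To make the interface analogy concrete, the sketch below uses Python’s standard abc module to declare a hypothetical Monad base class (not part of Python or of the earlier code) with an abstract bind method, and has LoggedInt implement it, with its constructor playing the role of the wrapper. Python cannot enforce the monad laws this way; the base class only documents which operation a monadic type is expected to provide.

```python
from abc import ABC, abstractmethod

# Hypothetical interface for illustration only.
class Monad(ABC):
    @abstractmethod
    def bind(self, function):
        """Apply a monadic function to the wrapped value and combine contexts."""

class LoggedInt(Monad):
    def __init__(self, val, logs=None):
        self.val = val
        self.logs = [] if logs is None else logs

    def bind(self, function):
        new_logged_int = function(self.val)
        return LoggedInt(new_logged_int.val, self.logs + new_logged_int.logs)

def square(a):
    return LoggedInt(a * a, [f"Squared {a}"])

result = LoggedInt(3).bind(square)
print(result.val, result.logs)    # 9 ['Squared 3']
print(isinstance(result, Monad))  # True
```

Trying to instantiate Monad directly raises a TypeError, since it leaves bind abstract, much like an interface that cannot be instantiated on its own.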
4.1. Properties of monads
A monad must provide a monadic type (LoggedInt in the example) and two operations:
- A constructor or wrapper[3], which receives a value of the base type (e.g. an integer) and wraps it into the monadic type. In the previous example, this was the __init__ constructor of LoggedInt.
- A binding function[4], which receives a monadic value and a monadic function. It applies the monadic function to the unwrapped value, and preserves the old monadic context by joining or combining it with the one returned by the monadic function.
Note how the monad doesn’t need to provide any “unwrapping” functionality that allows the caller to extract a simple value from a monadic one; that logic is part of the bind function, which unwraps the monadic value it receives before applying the monadic function to it.
Furthermore, monads must satisfy 3 laws that determine the behavior of the constructor and the binding function:
- Left identity
Promoting a simple value into the monadic type through the constructor, and then binding it to a monadic function is equivalent to applying the function to the simple value directly.
The previous writer example satisfies this law, because the following expressions are equivalent:
# Call the monadic function with a simple value.
result1 = square(6)

# Promote the simple value to the monadic type, and bind it to the monadic function.
result2 = LoggedInt(6).bind(square)
- Right identity
Binding a monadic value to the constructor doesn’t affect the monadic value. This ensures the constructor doesn’t alter the monadic environment.
In the previous writer example, this guarantees that the monadic constructor doesn’t initialize the log list (i.e. the monadic context) with any information. This can be checked because the following expressions are equivalent:
# Promote a simple value to a monadic type, through the constructor.
result1 = LoggedInt(6)

# Promote a simple value to a monadic type, and bind that to the constructor.
result2 = LoggedInt(6).bind(LoggedInt)
- Associativity
How nested binds are grouped does not affect the result. The order in which monadic functions are applied still matters (squaring and then adding five differs from adding five and then squaring); this law only states that, for a fixed order, binding the functions one at a time gives the same result as first combining them into a single monadic function and binding that, so the linking process performed by bind is unaffected by grouping. In the previous writer example:
# Define another simple one-argument monadic function.
def add_five(a):
    return add(a, 5)

# Promote a simple value to a monadic type, bind that to 'square', and then
# bind the result to 'add_five'.
result1 = LoggedInt(6).bind(square).bind(add_five)

# Promote a simple value to a monadic type, and bind that to a lambda which
# calls 'square' on its input and binds that to 'add_five'.
result2 = LoggedInt(6).bind(lambda a: square(a).bind(add_five))
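The three laws can be checked for the writer example with a short script. The sketch below repeats the relevant parts of the final code (with a None default for the log list) and adds a hypothetical helper, same, which treats two LoggedInt values as equivalent when both their values and their log lists are equal; Python’s default == on these objects would compare identity instead.

```python
class LoggedInt:
    def __init__(self, val, logs=None):
        self.val = val
        self.logs = [] if logs is None else logs

    def bind(self, function):
        new_logged_int = function(self.val)
        return LoggedInt(new_logged_int.val, self.logs + new_logged_int.logs)

def add(a, b):
    return LoggedInt(a + b, [f"Added {a} to {b}"])

def square(a):
    return LoggedInt(a * a, [f"Squared {a}"])

def add_five(a):
    return add(a, 5)

# Hypothetical helper: compares the value and the monadic context (the logs).
def same(x, y):
    return x.val == y.val and x.logs == y.logs

# Left identity: wrapping and then binding equals calling the function directly.
left = same(LoggedInt(6).bind(square), square(6))

# Right identity: binding to the constructor leaves the monadic value unchanged.
right = same(LoggedInt(6).bind(LoggedInt), LoggedInt(6))

# Associativity: the grouping of nested binds does not matter.
assoc = same(
    LoggedInt(6).bind(square).bind(add_five),
    LoggedInt(6).bind(lambda a: square(a).bind(add_five)),
)

print(left, right, assoc)  # True True True
```

Checking a single input like 6 only illustrates the laws; it does not prove them for every value and function.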
Footnotes:
[1] In lower-level programming languages, like C, one might argue that string-manipulation functions are not actually pure, since they often receive a pointer to data that might change between calls. I still decided to categorize the count_uppercase function as pure in higher-level programming languages, since it produces identical results when given the same string inputs.
[2] For example, if the length of the same string is calculated multiple times, and the string doesn’t change, the compiler/interpreter could perform a single call and reuse that value.
[3] This function is called return or pure in Haskell, but those names can be confusing, especially when used outside a do block. Note that pure is part of the Applicative typeclass, which is a superclass of Monad.
[4] This function is usually called bind, and in Haskell it is available as the >>= operator.