LIANNE LARDIZABAL

Software Engineer, BJJ Practitioner, Artist
Computer Science @ The University of British Columbia
Vancouver, BC

Currying in Scala

Combined with higher-order functions, abstractions let us build new abstractions out of existing ones, which increases the expressive power of a programming language.

Suppose we define a few summation functions in terms of a general function sum(f, a, b), which adds up f(x) for every integer x from a to b:

def sumInts(a: Int, b: Int) = sum(x => x, a, b)
def sumCubes(a: Int, b: Int) = sum(x => x * x * x, a, b)
def sumFactorials(a: Int, b: Int) = sum(fact, a, b)
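The three definitions above rely on a general sum(f, a, b) and a factorial function fact that the post does not show. As a minimal sketch, assuming sum adds f(x) for each integer x from a to b and fact is an ordinary recursive factorial (both bodies are our assumption), the snippet can run as a Scala script such as a .sc file:

```scala
// Assumed helper (not shown in the post): adds f(x) for every integer x in [a, b].
def sum(f: Int => Int, a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f, a + 1, b)

// Assumed factorial helper; the post uses `fact` without defining it.
def fact(n: Int): Int =
  if (n == 0) 1 else n * fact(n - 1)

// Each function passes its bounds straight through to sum.
def sumInts(a: Int, b: Int) = sum(x => x, a, b)
def sumCubes(a: Int, b: Int) = sum(x => x * x * x, a, b)
def sumFactorials(a: Int, b: Int) = sum(fact, a, b)

println(sumInts(1, 10))       // 55
println(sumCubes(1, 10))      // 3025
println(sumFactorials(1, 5))  // 153
```

Note how a and b appear in every signature only to be handed to sum unchanged; that repetition is what the rest of the post removes.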

These functions are already quite simple, but they are redundant: each one takes the bounds a and b only to pass them unchanged to sum. Is there a way to generalize them?

One option is to reduce the number of parameters these functions take. We can do this with a technique called currying. Currying is a form of function abstraction in which a function that takes multiple arguments is restructured into a series of functions, each of which takes a single argument. Scala supports this style because functions are first-class values: they can be passed as arguments and returned as results.
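To make the definition concrete, here is a small sketch using a hypothetical addition function (add, addCurried, addCurried2, and addFive are illustrative names only). It contrasts a two-argument function with its curried form, written by hand and produced by the standard library's Function2.curried:

```scala
// An ordinary function value that takes two arguments at once.
val add: (Int, Int) => Int = (x, y) => x + y

// The curried form: a function of x that returns a function of y.
val addCurried: Int => Int => Int = x => y => x + y

// Function2 in the standard library can build the curried form for us.
val addCurried2: Int => Int => Int = add.curried

// Supplying only the first argument yields a new one-argument function.
val addFive: Int => Int = addCurried(5)

println(add(5, 3))          // 8
println(addCurried(5)(3))   // 8
println(addCurried2(5)(3))  // 8
println(addFive(3))         // 8
```

The last line shows the practical payoff: fixing one argument early produces a reusable, more specialized function.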

Becoming Functional

Our summation functions are very similar, so let's turn sum into a more general function that takes only f and returns a function of the bounds a and b.

def sum(f: Int => Int): (Int, Int) => Int = {
    // sumF adds f(x) for every integer x from a to b
    def sumF(a: Int, b: Int): Int =
        if (a > b) 0
        else f(a) + sumF(a + 1, b)
    sumF
}

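Now we can rewrite the summation functions using the new general function sum. They no longer need parameters of their own, because sum(f) already returns a function of the bounds: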

def sumInts = sum(x => x)
def sumCubes = sum(x => x * x * x)
def sumFactorials = sum(fact)

To compute the sum of cubes from 1 to 10, we now simply call sumCubes(1, 10).
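Put together, a self-contained script of the curried version might look like the sketch below. The sum body is the post's own; the fact helper is our assumption, since the post uses fact without defining it:

```scala
// General sum: takes f and returns a function of the bounds (a, b).
def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int =
    if (a > b) 0
    else f(a) + sumF(a + 1, b)
  sumF
}

// Assumed factorial helper; the post uses `fact` without defining it.
def fact(n: Int): Int =
  if (n == 0) 1 else n * fact(n - 1)

// Each of these is a value of type (Int, Int) => Int.
def sumInts = sum(x => x)
def sumCubes = sum(x => x * x * x)
def sumFactorials = sum(fact)

println(sumInts(1, 10))       // 55
println(sumCubes(1, 10))      // 3025
println(sumFactorials(1, 5))  // 153
```

Because sumCubes takes no parameters and returns a function, sumCubes(1, 10) applies that returned function to the bounds.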

That is an improvement. However, each function is still very specific to a particular summation problem. Let’s abstract our functions even further.

Further Abstraction

In fact, we do not need separate named summation functions at all. Instead, we can apply the general function sum that we defined previously directly. To calculate the sum of cubes from 1 to 10, we write:

def cube(x: Int) = x * x * x
sum(cube)(1, 10)
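The call sum(cube)(1, 10) is really two applications in a row. The sketch below (the names sumOfCubes, stepwise, direct, and sumOfSquares are illustrative only) splits them apart to show that sum(cube) is itself a value of type (Int, Int) => Int. It also sums squares with an anonymous function, so no named helper is needed:

```scala
// The post's curried sum.
def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int =
    if (a > b) 0
    else f(a) + sumF(a + 1, b)
  sumF
}

def cube(x: Int): Int = x * x * x

// Step 1: applying sum to cube yields a function of the bounds.
val sumOfCubes: (Int, Int) => Int = sum(cube)

// Step 2: applying that function to (1, 10) computes the sum.
val stepwise = sumOfCubes(1, 10)

// Application associates to the left: sum(cube)(1, 10) means (sum(cube))(1, 10).
val direct = sum(cube)(1, 10)

// An anonymous function works just as well as a named one.
val sumOfSquares = sum(x => x * x)(1, 3)

println(stepwise == direct)  // true
println(sumOfSquares)        // 14
```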

Definitions

A definition of a function with multiple parameter lists:

def f(args_1)...(args_n) = E

where n > 1 is equivalent to:

def f(args_1)...(args_n-1) = {def g(args_n) = E; g}

where g is a fresh identifier.
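To make the rule concrete, here is a sketch applying it to sum with n = 2, written with Scala's multiple-parameter-list syntax: args_1 is (f) and args_2 is (a, b). The name sumExpanded is hypothetical, chosen only so both versions can sit side by side; its body is the rule's expansion with g as the fresh identifier, and E is kept as is apart from the renamed recursive call:

```scala
// Two parameter lists (n = 2): args_1 = (f), args_2 = (a, b).
def sum(f: Int => Int)(a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f)(a + 1, b)

// The expansion given by the rule: keep args_1, move args_n into a local g.
// E is unchanged except that the recursive call now names sumExpanded.
def sumExpanded(f: Int => Int): (Int, Int) => Int = {
  def g(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sumExpanded(f)(a + 1, b)
  g
}

println(sum(x => x * x * x)(1, 10))          // 3025
println(sumExpanded(x => x * x * x)(1, 10))  // 3025
```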

Or, for short, using an anonymous function:

def f(args_1)...(args_n-1) = {args_n => E}
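Applying the shorter form once per parameter list turns def f(args_1)...(args_n) = E into a chain of nested anonymous functions. A small sketch with a hypothetical volume function (volume, volumeCurried, and base2x3 are illustrative names) shows that both forms agree and that the nested form can be applied one stage at a time:

```scala
// Three parameter lists, one parameter each (n = 3).
def volume(l: Int)(w: Int)(h: Int): Int = l * w * h

// Repeating the rule for every list gives nested one-argument functions.
// The type Int => Int => Int => Int reads as Int => (Int => (Int => Int)).
def volumeCurried: Int => Int => Int => Int =
  l => w => h => l * w * h

println(volume(2)(3)(4))         // 24
println(volumeCurried(2)(3)(4))  // 24

// Each stage is an ordinary function value, so partial application is direct.
val base2x3: Int => Int = volumeCurried(2)(3)
println(base2x3(10))             // 60
```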

Reference:
Functional Programming Principles in Scala