Java - Fork/join, the work-stealing thread pool

[Updated: Oct 24, 2017, Created: Oct 15, 2016]

The Fork/Join framework uses a work-stealing algorithm. Work stealing is a scheduling strategy in which worker threads that have finished their own tasks can steal pending tasks from other threads. In parallel execution, tasks are divided among multiple processors/cores. When a core runs out of work, instead of sitting idle it should take a task from the queue of another, busier processor.

In a fork/join thread pool, each task checks whether its workload is too big to be processed efficiently in one piece (typically by comparing its size with a threshold). If it is, the task forks: it splits (decomposes) itself and pushes one part (ideally half) of the work to the pool as a new task, so that other, idle threads can take up (steal) that part. 'Join' means waiting for the forked tasks to finish and combining their results. Tasks are recursively split into smaller parts, run in parallel and recombined.
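To see the split, run in parallel, recombine cycle in action, here is a minimal sketch (not from the original project) that sums an array with a RecursiveTask. The class name WorkStealingDemo, the THRESHOLD value and the set of thread names are illustrative assumptions. The set records which threads ran leaf tasks; on a multi-core machine it usually contains several pool workers because idle workers steal the forked halves, but the exact set varies from run to run.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class WorkStealingDemo {

    // Hypothetical threshold: ranges of at most this many elements are summed sequentially.
    static final int THRESHOLD = 1_000;

    static class SumTask extends RecursiveTask<Long> {
        private final long[] numbers;
        private final int from, to;          // half-open range [from, to)
        private final Set<String> workers;   // names of the threads that ran leaf tasks

        SumTask(long[] numbers, int from, int to, Set<String> workers) {
            this.numbers = numbers;
            this.from = from;
            this.to = to;
            this.workers = workers;
        }

        @Override
        protected Long compute() {
            if (to - from <= THRESHOLD) {
                workers.add(Thread.currentThread().getName());
                long sum = 0;
                for (int i = from; i < to; i++) {
                    sum += numbers[i];
                }
                return sum;
            }
            int middle = (from + to) >>> 1;
            SumTask right = new SumTask(numbers, middle, to, workers);
            right.fork();                                        // schedule the right half; an idle worker may steal it
            long leftSum = new SumTask(numbers, from, middle, workers).compute(); // left half in this thread
            return leftSum + right.join();                       // wait for the right half and combine
        }
    }

    static long sum(long[] numbers, Set<String> workers) {
        return ForkJoinPool.commonPool().invoke(new SumTask(numbers, 0, numbers.length, workers));
    }

    public static void main(String[] args) {
        long[] numbers = new long[1_000_000];
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = i + 1;
        }
        Set<String> workers = ConcurrentHashMap.newKeySet();
        System.out.println("sum = " + sum(numbers, workers));
        System.out.println("leaf tasks ran on " + workers.size() + " thread(s): " + workers);
    }
}
```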




Fork/join API

Java 7 introduced the Fork/Join framework with the following API:



ForkJoinPool class

The ForkJoinPool class implements ExecutorService. Tasks submitted to this pool are normally subclasses of ForkJoinTask (plain Runnable and Callable tasks are wrapped into ForkJoinTask instances).

Since this class does not extend ThreadPoolExecutor, it does not use ThreadPoolExecutor's core pool size, work queue and maximum pool size settings; instead, it implements the work-stealing algorithm.

This class provides a static method (added in Java 8) to get the common pool instance:

public static ForkJoinPool commonPool()

The pool instance returned by the above method is initialized in a static block with default settings. Using one global pool makes sense because all parts of the application share it, so it can spread work across the available cores instead of several separate pools competing for them.
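A quick way to inspect the shared pool is a small program like the following sketch (the class name is ours). On a typical JDK 8 installation the common pool's parallelism defaults to one less than the number of available processors (at least 1), and it can be overridden with the system property java.util.concurrent.ForkJoinPool.common.parallelism; the program prints the actual values rather than assuming them.

```java
import java.util.concurrent.ForkJoinPool;

public class CommonPoolInfo {
    public static void main(String[] args) {
        ForkJoinPool common = ForkJoinPool.commonPool();

        // Every call returns the same shared instance.
        System.out.println("same instance: " + (common == ForkJoinPool.commonPool()));

        // Target number of worker threads of the shared pool.
        System.out.println("common pool parallelism: " + common.getParallelism());
        System.out.println("getCommonPoolParallelism(): " + ForkJoinPool.getCommonPoolParallelism());

        // For comparison: the number of cores visible to the JVM.
        System.out.println("available processors: " + Runtime.getRuntime().availableProcessors());
    }
}
```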

This class also provides public constructors:

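/* Initializes the pool with default settings
   (parallelism equal to the number of available processors) */
ForkJoinPool()

/* With the given parallelism level, i.e. the target number of active
   worker threads (roughly, how many cores to use) */
ForkJoinPool(int parallelism)

/**
 * parallelism: the parallelism level.
 * factory: the factory for creating new worker threads.
 * handler: handler for worker threads that terminate due to
 *          unrecoverable errors (null for the default).
 * asyncMode: if true, forked tasks that are never joined are scheduled
 *            first-in-first-out, which suits event-style asynchronous tasks.
 *            By default it's false (last-in-first-out), which is good for
 *            normal recursive computations.
 */
ForkJoinPool(int parallelism, ForkJoinPool.ForkJoinWorkerThreadFactory factory,
             Thread.UncaughtExceptionHandler handler,
             boolean asyncMode)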
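As an illustration of the four-argument constructor, here is a sketch that builds a private pool. The class and method names are hypothetical; ForkJoinPool.defaultForkJoinWorkerThreadFactory is the JDK's standard factory, and ForkJoinTask.adapt(Callable) wraps a plain Callable so that it can be passed to invoke(). It is good practice to shut down a pool we create ourselves (calls to shutdown() on the common pool have no effect).

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.TimeUnit;

public class CustomPoolDemo {

    static ForkJoinPool newPool(int parallelism, boolean asyncMode) {
        return new ForkJoinPool(
                parallelism,                                      // target number of active workers
                ForkJoinPool.defaultForkJoinWorkerThreadFactory,  // standard worker-thread factory
                (thread, error) ->                                // called if a worker dies from an error
                        System.err.println(thread.getName() + " failed: " + error),
                asyncMode);                                       // true = FIFO for never-joined forked tasks
    }

    public static void main(String[] args) throws InterruptedException {
        ForkJoinPool pool = newPool(2, false);
        try {
            // adapt() turns a Callable into a ForkJoinTask that invoke() accepts.
            String result = pool.invoke(
                    ForkJoinTask.adapt(() -> "ran on " + Thread.currentThread().getName()));
            System.out.println(result);
        } finally {
            pool.shutdown();
            pool.awaitTermination(5, TimeUnit.SECONDS);
        }
    }
}
```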


To submit the main (root) task, we call one of these methods:

  • invoke(ForkJoinTask): waits for the task to finish and returns its result.
  • execute(ForkJoinTask): async call; returns nothing.
  • submit(ForkJoinTask): async call which returns the task itself as a Future (a ForkJoinTask).
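The three submission methods can be compared side by side in a small sketch. SumUpTo is a hypothetical, deliberately tiny task that does not split itself. Since execute() returns nothing, the sketch keeps a reference to the task and calls join() on it later, while submit() hands the same task back as a ForkJoinTask, which implements Future.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class SubmitDemo {

    // Hypothetical root task: sums the integers from 1 to n without splitting.
    static class SumUpTo extends RecursiveTask<Long> {
        private final int n;

        SumUpTo(int n) {
            this.n = n;
        }

        @Override
        protected Long compute() {
            long sum = 0;
            for (int i = 1; i <= n; i++) {
                sum += i;
            }
            return sum;
        }
    }

    static long[] runAll(ForkJoinPool pool) {
        // invoke(): blocks until the task completes and returns its result.
        long a = pool.invoke(new SumUpTo(100));

        // execute(): returns immediately; we keep the task to wait for it later.
        SumUpTo executed = new SumUpTo(200);
        pool.execute(executed);

        // submit(): returns immediately with the task itself (a ForkJoinTask, i.e. a Future).
        ForkJoinTask<Long> submitted = pool.submit(new SumUpTo(300));

        return new long[] { a, executed.join(), submitted.join() };
    }

    public static void main(String[] args) {
        long[] results = runAll(ForkJoinPool.commonPool());
        System.out.println(java.util.Arrays.toString(results)); // [5050, 20100, 45150]
    }
}
```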


ForkJoinTask class

This is an abstract base class representing a task that can split itself (fork) and join back. The following are its important methods:

The ForkJoinTask#fork() method submits this task to the pool to run asynchronously. We call this method on the new task instance that we split off from the main task. We will see shortly in an example how to do that.

The ForkJoinTask#join() method waits until the task is done and returns the result of its computation.
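The usual fork, compute, join sequence can be sketched with the classic recursive Fibonacci computation, used here only as an illustration (the THRESHOLD value is a hypothetical cut-off that keeps small subproblems sequential). The order matters: the task forks one half, computes the other half in the current thread, and only then calls join().

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class Fibonacci extends RecursiveTask<Long> {
    private static final int THRESHOLD = 10;   // at or below this, recursion stays sequential
    private final int n;

    public Fibonacci(int n) {
        this.n = n;
    }

    @Override
    protected Long compute() {
        if (n <= THRESHOLD) {
            return fibSequential(n);
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork();                                   // schedule fib(n-1) asynchronously
        long f2 = new Fibonacci(n - 2).compute();    // compute fib(n-2) in this thread
        return f2 + f1.join();                       // wait for the forked half, then combine
    }

    private static long fibSequential(int n) {
        return n <= 1 ? n : fibSequential(n - 1) + fibSequential(n - 2);
    }

    public static void main(String[] args) {
        System.out.println(ForkJoinPool.commonPool().invoke(new Fibonacci(30))); // 832040
    }
}
```

Calling f1.join() before computing the other half would make the current thread wait for the forked task first, so the two halves would effectively run one after the other.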

We can extend one of the following three abstract subclasses of ForkJoinTask:

  • RecursiveAction: This task does not return a result. We have to override its only abstract method:
    protected abstract void compute()
    The compute() implementation performs the application-specific computation.


  • RecursiveTask<V>: The subclass returns a computed value.
    It has one abstract method to be implemented by us:
    protected abstract V compute()


  • CountedCompleter<V>: This class was introduced in Java 8.
    It keeps track of pending tasks: we have to call setPendingCount or addToPendingCount for each new task that is forked within 'this' task.

    On completion, each task's onCompletion method gets called (we can override this method in our task class). We also have to call tryComplete() (or propagateCompletion(), which does the same but does not call onCompletion) when our task finishes, so that the pending count is decremented and completion can propagate to the parent task.

    CountedCompleter also arranges tasks in a tree: each newly spawned (forked) task is created with 'this' as its parent (its completer). We can access the parent by calling CountedCompleter#getCompleter().

    A CountedCompleter may optionally return a result by overriding getRawResult; that value is returned from join() or invoke() on the caller side. If we don't want to return a value, we extend it as CountedCompleter<Void>.

    Here's how everything works: we call addToPendingCount(1) for every fork(). In compute(), when the work is done, we call tryComplete(). If the task's pending count is not zero, tryComplete() just decrements it; if it is already zero, onCompletion() is called and the same step is repeated for the parent task. When this reaches the main (root) task, that task completes and the main invoke() call returns.


    This class gives more control over task completion than the other two, but it can be error-prone if not used properly.

    We have to implement the following method:
    public abstract void compute()
    Optionally, we can override this method:
    public void onCompletion(CountedCompleter<?> caller){}
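Returning a result from a CountedCompleter through getRawResult can be sketched with an array sum. This follows the shape of the map-reduce example in the CountedCompleter Javadoc but is not taken from the original project; SumCompleter, THRESHOLD and the left/right fields are our own names. Each inner task forks its right child and computes its left child directly with setPendingCount(1): two children report back, and the second one to finish finds the count at zero, so onCompletion() runs once, when both child results are available.

```java
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;

public class SumCompleter extends CountedCompleter<Long> {
    private static final int THRESHOLD = 1_000;   // hypothetical leaf size
    private final long[] array;
    private final int lo, hi;                     // half-open range [lo, hi)
    private SumCompleter left, right;             // children (null for leaves)
    private long result;

    public SumCompleter(CountedCompleter<?> parent, long[] array, int lo, int hi) {
        super(parent);
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    public void compute() {
        if (hi - lo > THRESHOLD) {
            int mid = (lo + hi) >>> 1;
            left = new SumCompleter(this, array, lo, mid);
            right = new SumCompleter(this, array, mid, hi);
            // Two children will report back; the count is one less because the
            // last arrival finds it at zero and completes this task.
            setPendingCount(1);
            right.fork();
            left.compute();
        } else {
            for (int i = lo; i < hi; i++) {
                result += array[i];
            }
            tryComplete();   // completes this leaf and propagates towards the root
        }
    }

    @Override
    public void onCompletion(CountedCompleter<?> caller) {
        if (left != null) {  // inner node: both children have finished at this point
            result = left.result + right.result;
        }
    }

    @Override
    public Long getRawResult() {  // the value returned by invoke() and join()
        return result;
    }

    public static long sum(long[] array) {
        return ForkJoinPool.commonPool().invoke(new SumCompleter(null, array, 0, array.length));
    }

    public static void main(String[] args) {
        long[] numbers = new long[100_000];
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = i + 1;
        }
        System.out.println(sum(numbers)); // 5000050000
    }
}
```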
    


RecursiveAction Example

In this example we calculate the factorials of a list of numbers:

  import java.math.BigInteger;
  import java.util.List;
  import java.util.concurrent.RecursiveAction;

  public class FactorialTask extends RecursiveAction {
    // in a real app this could be in the thousands
    private static final int SEQUENTIAL_THRESHOLD = 5;
    private List<BigInteger> integerList;

    public FactorialTask (List<BigInteger> integerList) {
        this.integerList = integerList;
    }

    @Override
    protected void compute () {
       if (integerList.size() <= SEQUENTIAL_THRESHOLD) {
           showFactorials();
       } else {
           //splitting
           int middle = integerList.size() / 2;
           List<BigInteger> newList = integerList.subList(middle, integerList.size());
           integerList = integerList.subList(0, middle);

           FactorialTask task = new FactorialTask(newList);
           //fork() returns immediately; it does not create a new thread but
           //queues the task in the pool, where an idle worker may steal it
           task.fork();
           this.compute();
           //wait for the forked half, otherwise invoke() could return before it is done
           task.join();
       }
    }

    private void showFactorials () {
        for (BigInteger i : integerList) {
            //prints the factorial of each number
            //(CalcUtil is a helper class from the example project, not shown here)
            System.out.println(i + "! = " + CalcUtil.calculateFactorial(i));
        }
    }
  }

Submitting the initial task:

 public static void main (String[] args) {
        List<BigInteger> list = ... //creating a very big list
        ForkJoinPool.commonPool().invoke(new FactorialTask(list));
    }

In the above example we could pass a Spliterator instead of a List. That would make the example more general, and we could use Spliterator#trySplit() for splitting instead of doing it ourselves. An example is included in the code project at the bottom.
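A Spliterator-based variant of the RecursiveAction example might look like the following sketch (it is not the code from the example project). It splits with Spliterator#trySplit(), which returns null when no further split is possible, and joins the forked part before returning so that invoke() does not finish early. To keep it testable it collects the output lines in a ConcurrentLinkedQueue instead of printing from the workers, and factorial() is a hypothetical stand-in for the project's CalcUtil helper.

```java
import java.math.BigInteger;
import java.util.List;
import java.util.Spliterator;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

public class SpliteratorFactorialTask extends RecursiveAction {
    private static final int SEQUENTIAL_THRESHOLD = 5;
    private final Spliterator<BigInteger> spliterator;
    private final ConcurrentLinkedQueue<String> output;   // collected lines (assumption)

    public SpliteratorFactorialTask(Spliterator<BigInteger> spliterator,
                                    ConcurrentLinkedQueue<String> output) {
        this.spliterator = spliterator;
        this.output = output;
    }

    @Override
    protected void compute() {
        Spliterator<BigInteger> other;
        // trySplit() hands part of the remaining elements to a new Spliterator,
        // or returns null if the source cannot be split any further.
        if (spliterator.estimateSize() > SEQUENTIAL_THRESHOLD
                && (other = spliterator.trySplit()) != null) {
            SpliteratorFactorialTask task = new SpliteratorFactorialTask(other, output);
            task.fork();
            compute();      // keep working on the part that stayed in this spliterator
            task.join();    // the forked part must finish before this task completes
        } else {
            spliterator.forEachRemaining(n -> output.add(n + "! = " + factorial(n)));
        }
    }

    // Hypothetical helper standing in for the example project's CalcUtil.
    static BigInteger factorial(BigInteger n) {
        BigInteger result = BigInteger.ONE;
        for (BigInteger i = BigInteger.valueOf(2); i.compareTo(n) <= 0; i = i.add(BigInteger.ONE)) {
            result = result.multiply(i);
        }
        return result;
    }

    public static void main(String[] args) {
        List<BigInteger> list = LongStream.rangeClosed(1, 20)
                .mapToObj(BigInteger::valueOf)
                .collect(Collectors.toList());
        ConcurrentLinkedQueue<String> output = new ConcurrentLinkedQueue<>();
        ForkJoinPool.commonPool().invoke(new SpliteratorFactorialTask(list.spliterator(), output));
        output.forEach(System.out::println);   // order depends on scheduling
    }
}
```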



RecursiveTask Example

As mentioned above, RecursiveTask#compute returns a value. In the following example, we calculate the sum of all factorials:

 import java.math.BigInteger;
 import java.util.List;
 import java.util.concurrent.RecursiveTask;

 public class FactorialTask extends RecursiveTask<BigInteger> {
     private static final int SEQUENTIAL_THRESHOLD = 5;
     private List<BigInteger> integerList;

     public FactorialTask (List<BigInteger> integerList) {
         this.integerList = integerList;
     }

     @Override
     protected BigInteger compute () {
        if (integerList.size() <= SEQUENTIAL_THRESHOLD) {
            return sumFactorials();
        } else {
            int middle = integerList.size() / 2;
            List<BigInteger> newList = integerList.subList(middle, integerList.size());
            integerList = integerList.subList(0, middle);
            FactorialTask task = new FactorialTask(newList);
            task.fork();                          // second half runs asynchronously
            BigInteger thisSum = this.compute();  // first half runs in this thread
            BigInteger thatSum = task.join();     // wait for the forked half
            return thisSum.add(thatSum);
        }
     }

     private BigInteger sumFactorials () {
        BigInteger sum = BigInteger.ZERO;
        for (BigInteger i : integerList) {
            // CalcUtil is a helper class from the example project (not shown here)
            sum = sum.add(CalcUtil.calculateFactorial(i));
        }
        return sum;
     }
 }

Submitting the initial task is the same as in the last example.
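The fork(), compute(), join() sequence in the RecursiveTask example can also be written with the static ForkJoinTask.invokeAll(t1, t2) method, which forks the given tasks and returns once both are done. The following self-contained sketch computes the same sum of factorials that way; FactorialSumTask and its factorial() helper (standing in for CalcUtil) are hypothetical names.

```java
import java.math.BigInteger;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

public class FactorialSumTask extends RecursiveTask<BigInteger> {
    private static final int SEQUENTIAL_THRESHOLD = 5;
    private final List<BigInteger> integerList;

    public FactorialSumTask(List<BigInteger> integerList) {
        this.integerList = integerList;
    }

    @Override
    protected BigInteger compute() {
        if (integerList.size() <= SEQUENTIAL_THRESHOLD) {
            BigInteger sum = BigInteger.ZERO;
            for (BigInteger n : integerList) {
                sum = sum.add(factorial(n));
            }
            return sum;
        }
        int middle = integerList.size() / 2;
        FactorialSumTask left = new FactorialSumTask(integerList.subList(0, middle));
        FactorialSumTask right = new FactorialSumTask(integerList.subList(middle, integerList.size()));
        invokeAll(left, right);                  // forks both tasks and waits until both are done
        return left.join().add(right.join());    // results are already available here
    }

    // Hypothetical helper standing in for CalcUtil.calculateFactorial.
    static BigInteger factorial(BigInteger n) {
        BigInteger result = BigInteger.ONE;
        for (BigInteger i = BigInteger.valueOf(2); i.compareTo(n) <= 0; i = i.add(BigInteger.ONE)) {
            result = result.multiply(i);
        }
        return result;
    }

    public static BigInteger sumOfFactorials(int upTo) {
        List<BigInteger> list = LongStream.rangeClosed(1, upTo)
                .mapToObj(BigInteger::valueOf)
                .collect(Collectors.toList());
        return ForkJoinPool.commonPool().invoke(new FactorialSumTask(list));
    }

    public static void main(String[] args) {
        System.out.println(sumOfFactorials(10)); // 4037913
    }
}
```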



CountedCompleter Example

Here we modify our first example to use CountedCompleter:

import java.math.BigInteger;
import java.util.List;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;

public class FactorialTask extends CountedCompleter<Void> {

    private static final int SEQUENTIAL_THRESHOLD = 5;
    private List<BigInteger> integerList;

    private FactorialTask (CountedCompleter<Void> parent,
                        List<BigInteger> integerList) {
       super(parent);
       this.integerList = integerList;
    }


    @Override
    public void compute () {
       if (integerList.size() <= SEQUENTIAL_THRESHOLD) {
            showFactorials();
            //decrements this task's pending count or, if it is already zero,
            //completes this task and propagates completion to its parent
            tryComplete();
       } else {
            int middle = integerList.size() / 2;
            List<BigInteger> newList = integerList.subList(middle, integerList.size());
            integerList = integerList.subList(0, middle);
            // add 1 because we are going to fork just one task
            addToPendingCount(1);
            FactorialTask task = new FactorialTask(this, newList);
            task.fork();
            this.compute();
       }
    }

    @Override
    public void onCompletion (CountedCompleter<?> caller) {
        if (caller == this) {
            //do something e.g. clean up or do some final computation etc;
        }
    }

    private void showFactorials () {
        for (BigInteger i : integerList) {
            // CalcUtil is a helper class from the example project (not shown here)
            BigInteger factorial = CalcUtil.calculateFactorial(i);
            System.out.println(i + "! = " + factorial);
        }
    }

    public static void main (String[] args) {
        List<BigInteger> list = ...

        ForkJoinPool.commonPool().invoke(
                            new FactorialTask(null, list));
    }
}
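To check that the pending-count bookkeeping of the CountedCompleter example really makes invoke() wait for every forked part, here is a self-contained variant. Instead of printing, each leaf adds its factorials to a shared AtomicReference<BigInteger>; that accumulator, the class name and the factorial() helper (standing in for CalcUtil) are our own assumptions, while the splitting and the addToPendingCount(1) / tryComplete() calls follow the example above.

```java
import java.math.BigInteger;
import java.util.List;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

public class FactorialSumCompleter extends CountedCompleter<Void> {
    private static final int SEQUENTIAL_THRESHOLD = 5;
    private final AtomicReference<BigInteger> total;   // shared accumulator (assumption)
    private List<BigInteger> integerList;

    public FactorialSumCompleter(CountedCompleter<?> parent, List<BigInteger> integerList,
                                 AtomicReference<BigInteger> total) {
        super(parent);
        this.integerList = integerList;
        this.total = total;
    }

    @Override
    public void compute() {
        if (integerList.size() <= SEQUENTIAL_THRESHOLD) {
            for (BigInteger n : integerList) {
                total.accumulateAndGet(factorial(n), BigInteger::add);
            }
            tryComplete();                 // report that this task's own part is done
        } else {
            int middle = integerList.size() / 2;
            List<BigInteger> newList = integerList.subList(middle, integerList.size());
            integerList = integerList.subList(0, middle);
            addToPendingCount(1);          // the forked child will report back once
            new FactorialSumCompleter(this, newList, total).fork();
            compute();                     // keep splitting/processing the first half here
        }
    }

    // Hypothetical helper standing in for CalcUtil.calculateFactorial.
    static BigInteger factorial(BigInteger n) {
        BigInteger result = BigInteger.ONE;
        for (BigInteger i = BigInteger.valueOf(2); i.compareTo(n) <= 0; i = i.add(BigInteger.ONE)) {
            result = result.multiply(i);
        }
        return result;
    }

    public static BigInteger sumOfFactorials(int upTo) {
        List<BigInteger> list = LongStream.rangeClosed(1, upTo)
                .mapToObj(BigInteger::valueOf)
                .collect(Collectors.toList());
        AtomicReference<BigInteger> total = new AtomicReference<>(BigInteger.ZERO);
        ForkJoinPool.commonPool().invoke(new FactorialSumCompleter(null, list, total));
        return total.get();                // invoke() returned, so every leaf has finished
    }

    public static void main(String[] args) {
        System.out.println(sumOfFactorials(10)); // 4037913
    }
}
```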



Example Project

Dependencies and Technologies Used:

  • JDK 1.8
  • Maven 3.0.4

Fork And Join Examples
  • fork-and-join-examples
    • src
      • main
        • java
          • com
            • logicbig
              • example

See Also