← 返回首页
🔀

Java Fork/Join框架详解:并行分治算法

📂 java ⏱ 5 min 822 words

Java Fork/Join框架详解:并行分治算法

概述

Fork/Join框架是Java 7引入的并行执行框架,专为可递归拆分的分治(divide-and-conquer)任务设计。它采用工作窃取(work-stealing)算法:每个工作线程维护自己的任务队列,空闲线程会从其他线程的队列中"窃取"尚未执行的任务,从而让各个核心都保持忙碌,高效地利用多核处理器。

1. Fork/Join基本概念

import java.util.concurrent.*;

public class ForkJoinBasic {
    // RecursiveTask:有返回值
    // RecursiveAction:无返回值
    
    // 示例:计算数组和
    static class SumTask extends RecursiveTask<Long> {
        private static final int THRESHOLD = 10000;
        private long[] array;
        private int start;
        private int end;
        
        public SumTask(long[] array, int start, int end) {
            this.array = array;
            this.start = start;
            this.end = end;
        }
        
        @Override
        protected Long compute() {
            int length = end - start;
            
            if (length <= THRESHOLD) {
                // 直接计算
                long sum = 0;
                for (int i = start; i < end; i++) {
                    sum += array[i];
                }
                return sum;
            }
            
            // 分割任务
            int mid = start + length / 2;
            SumTask leftTask = new SumTask(array, start, mid);
            SumTask rightTask = new SumTask(array, mid, end);
            
            // 异步执行左任务
            leftTask.fork();
            
            // 同步执行右任务
            long rightResult = rightTask.compute();
            
            // 等待左任务完成
            long leftResult = leftTask.join();
            
            return leftResult + rightResult;
        }
    }
    
    public static void main(String[] args) {
        // 创建测试数据
        long[] array = new long[100000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        
        // 创建线程池
        ForkJoinPool pool = new ForkJoinPool();
        
        // 提交任务
        SumTask task = new SumTask(array, 0, array.length);
        long result = pool.invoke(task);
        
        System.out.println("计算结果: " + result);
        System.out.println("期望结果: " + (100000L * 100001L / 2));
        
        pool.shutdown();
    }
}

2. 工作窃取算法

import java.util.concurrent.*;

public class WorkStealingExample {
    /**
     * Prints every integer in [start, end), prefixed with the name of the worker thread that
     * handled it. Ranges wider than THRESHOLD are split in half and both halves are submitted
     * together, so the output shows how the pool spreads the pieces over its threads.
     */
    static class CustomRecursiveAction extends RecursiveAction {
        /** Largest range that is printed directly instead of being split further. */
        private static final int THRESHOLD = 10;

        private final int start;
        private final int end;

        public CustomRecursiveAction(int start, int end) {
            this.start = start;
            this.end = end;
        }

        @Override
        protected void compute() {
            final int width = end - start;
            if (width > THRESHOLD) {
                final int pivot = start + width / 2;
                // invokeAll forks the given sub-actions and returns once all of them are done.
                invokeAll(new CustomRecursiveAction(start, pivot),
                          new CustomRecursiveAction(pivot, end));
                return;
            }

            final String worker = Thread.currentThread().getName();
            for (int value = start; value < end; value++) {
                System.out.println(worker + " 处理: " + value);
            }
        }
    }

    /** Runs the demo over 0..99 on a fresh pool. */
    public static void main(String[] args) {
        final ForkJoinPool pool = new ForkJoinPool();
        pool.invoke(new CustomRecursiveAction(0, 100));
        pool.shutdown();
    }
}

3. 实际应用示例

并行排序

import java.util.concurrent.*;

public class ParallelMergeSort {
    private static final int THRESHOLD = 1000;
    
    static class MergeSortTask extends RecursiveAction {
        private int[] array;
        private int left;
        private int right;
        
        public MergeSortTask(int[] array, int left, int right) {
            this.array = array;
            this.left = left;
            this.right = right;
        }
        
        @Override
        protected void compute() {
            if (right - left <= THRESHOLD) {
                // 插入排序
                insertionSort(array, left, right);
                return;
            }
            
            int mid = left + (right - left) / 2;
            MergeSortTask leftTask = new MergeSortTask(array, left, mid);
            MergeSortTask rightTask = new MergeSortTask(array, mid, right);
            
            invokeAll(leftTask, rightTask);
            
            merge(array, left, mid, right);
        }
        
        private void insertionSort(int[] array, int left, int right) {
            for (int i = left + 1; i < right; i++) {
                int key = array[i];
                int j = i - 1;
                
                while (j >= left && array[j] > key) {
                    array[j + 1] = array[j];
                    j--;
                }
                array[j + 1] = key;
            }
        }
        
        private void merge(int[] array, int left, int mid, int right) {
            int[] temp = new int[right - left];
            int i = left, j = mid, k = 0;
            
            while (i < mid && j < right) {
                if (array[i] <= array[j]) {
                    temp[k++] = array[i++];
                } else {
                    temp[k++] = array[j++];
                }
            }
            
            while (i < mid) {
                temp[k++] = array[i++];
            }
            
            while (j < right) {
                temp[k++] = array[j++];
            }
            
            System.arraycopy(temp, 0, array, left, temp.length);
        }
    }
    
    public static void main(String[] args) {
        int[] array = new int[100000];
        for (int i = 0; i < array.length; i++) {
            array[i] = (int) (Math.random() * 1000000);
        }
        
        ForkJoinPool pool = new ForkJoinPool();
        
        long start = System.currentTimeMillis();
        MergeSortTask task = new MergeSortTask(array, 0, array.length);
        pool.invoke(task);
        long end = System.currentTimeMillis();
        
        System.out.println("排序完成,耗时: " + (end - start) + "ms");
        
        pool.shutdown();
    }
}

并行搜索

import java.util.concurrent.*;

public class ParallelSearch {
    static class SearchTask extends RecursiveTask<Integer> {
        private int[] array;
        private int target;
        private int start;
        private int end;
        
        public SearchTask(int[] array, int target, int start, int end) {
            this.array = array;
            this.target = target;
            this.start = start;
            this.end = end;
        }
        
        @Override
        protected Integer compute() {
            int length = end - start;
            
            if (length <= 1000) {
                // 线性搜索
                for (int i = start; i < end; i++) {
                    if (array[i] == target) {
                        return i;
                    }
                }
                return -1;
            }
            
            // 分割任务
            int mid = start + length / 2;
            SearchTask leftTask = new SearchTask(array, target, start, mid);
            SearchTask rightTask = new SearchTask(array, target, mid, end);
            
            leftTask.fork();
            int rightResult = rightTask.compute();
            int leftResult = leftTask.join();
            
            if (leftResult != -1) {
                return leftResult;
            }
            return rightResult;
        }
    }
    
    public static void main(String[] args) {
        int[] array = new int[1000000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i;
        }
        
        ForkJoinPool pool = new ForkJoinPool();
        
        int target = 999999;
        SearchTask task = new SearchTask(array, target, 0, array.length);
        int result = pool.invoke(task);
        
        System.out.println("搜索结果: " + result);
        
        pool.shutdown();
    }
}

4. 最佳实践

  1. 设置合适的阈值:根据任务特性设置THRESHOLD
  2. 避免任务过小:任务太小会增加调度开销
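  3. 使用invokeAll,或采用"fork一个子任务、在当前线程compute另一个、最后join"的模式:让当前线程也参与计算,避免fork之后立即join而空等
  4. 合理使用ForkJoinPool:尽量复用同一个线程池(例如ForkJoinPool.commonPool()),避免反复创建线程池导致线程过多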
  5. 监控线程池状态:定期检查线程池的活跃线程数

总结

Fork/Join框架是Java并行编程的核心工具。掌握其使用方法,可以实现高效的并行分治算法。在实际编程中,要根据任务特性合理设置阈值和线程池参数。