java – Parallelizing short computations

Question:

Problem: I need to parallelize the computation of several models over N consecutive iterations (the example below uses two models, the benchmark uses three). In each iteration every model performs a short computation using a numerical method (Euler, Runge-Kutta or another). After all models have been computed in an iteration, they exchange results. This changes each model's state and therefore affects the results of the next iteration, so the order of the iterations matters (the model describes how its state evolves over time through numerical integration in small time steps). The following is a simplified, made-up example, since the real code is more complex:

for (int i = 0; i < N; i++) {
  model1.calculate();
  model2.calculate();

  model2.setU(model1.getU());
  model1.setI(model2.getI());
}
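Written as one explicit Euler step per iteration, the loop above corresponds to the following scheme. Here x and y stand for the states of model1 and model2, h for the time step, and f, g, p, q are hypothetical names for each model's right-hand side and for the values returned by getU() and getI():

```latex
\begin{aligned}
x_{n+1} &= x_n + h\,f(x_n, I_n), &\qquad y_{n+1} &= y_n + h\,g(y_n, U_n),\\
U_{n+1} &= p(x_{n+1}), &\qquad I_{n+1} &= q(y_{n+1}).
\end{aligned}
```

The two updates in the first row depend only on values from step n, so they can run at the same time; the second row needs both of them, and step n + 1 needs the second row.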

The number of iterations can be large: the simulation has to run at up to 100,000 iterations per real second. Only the work inside the current iteration can be parallelized, i.e. model1.calculate() and model2.calculate(). So far we have not managed to write an efficient parallel version: the sequential one is consistently faster.
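Since every iteration is the same "calculate all models, then exchange" cycle, one commonly used alternative to submitting tasks on every iteration is to keep one long-lived thread per model and separate the iterations with a barrier. The following is a minimal sketch under stated assumptions: ToyModel, its Euler update and the two-model exchange are hypothetical stand-ins, not the question's Model or ModelExchanger. It relies on java.util.concurrent.CyclicBarrier, whose barrier action runs once per cycle in the last thread to arrive, before the other threads are released; here that action performs the exchange.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

/**
 * Sketch: one long-lived thread per model. A CyclicBarrier ends each iteration and its
 * barrier action performs the exchange, so no tasks are submitted per iteration.
 */
public class BarrierStepper {

    /** Hypothetical toy model standing in for the real Model class. */
    static final class ToyModel {
        private final double rate;
        private double state;
        private double input;

        ToyModel(double rate, double initial) { this.rate = rate; this.state = initial; }

        void calculate(double dt) { state += (input - rate * state) * dt; } // one explicit Euler step
        void setInput(double value) { input = value; }
        double getOutput() { return state; }
    }

    static ToyModel[] newModels() {
        return new ToyModel[] { new ToyModel(0.5, 1.0), new ToyModel(0.25, 2.0) };
    }

    /** Hypothetical two-model exchange: each model receives the other model's output. */
    static void exchange(ToyModel[] m) {
        double out0 = m[0].getOutput();
        double out1 = m[1].getOutput();
        m[0].setInput(out1);
        m[1].setInput(out0);
    }

    /** Reference result: calculate both models, then exchange, all on one thread. */
    public static double runSequential(int steps, double dt) {
        ToyModel[] m = newModels();
        for (int s = 0; s < steps; s++) {
            for (ToyModel model : m) model.calculate(dt);
            exchange(m);
        }
        return m[0].getOutput();
    }

    /** Same simulation, with the calculate() calls of one iteration running concurrently. */
    public static double runParallel(int steps, double dt) {
        ToyModel[] m = newModels();
        // The barrier action runs once per iteration, in the last thread to arrive:
        // after every calculate() of that iteration and before any worker continues.
        CyclicBarrier barrier = new CyclicBarrier(m.length, () -> exchange(m));
        Thread[] workers = new Thread[m.length];
        for (int w = 0; w < m.length; w++) {
            ToyModel model = m[w];
            workers[w] = new Thread(() -> {
                try {
                    for (int s = 0; s < steps; s++) {
                        model.calculate(dt);
                        barrier.await();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt(); // stop this worker
                } catch (BrokenBarrierException e) {
                    // another worker failed; stop this one as well
                }
            });
            workers[w].start();
        }
        for (Thread t : workers) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while waiting for workers", e);
            }
        }
        return m[0].getOutput();
    }

    public static void main(String[] args) {
        System.out.println(runSequential(10_000, 1e-3));
        System.out.println(runParallel(10_000, 1e-3));
    }
}
```

Both methods print the same value, because each model's arithmetic is unchanged and the barrier orders every calculate() before the exchange. Whether this is faster than the sequential loop is not measured here: await() still parks and wakes threads on every iteration, so it can only pay off when calculate() costs clearly more than that hand-off.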

Here are my attempts at parallelizing it:

package jcscore;

public class Model {
    private static final int SYSTEM_EQUATION_LENGTH = 200;
    private static final int LAST_INDEX = SYSTEM_EQUATION_LENGTH - 1;
    private static final int INPUT_INDEX = 0;
    private static final int OUTPUT_INDEX = 5;
    private final double[] factors = new double[SYSTEM_EQUATION_LENGTH];
    private final double[] state = new double[SYSTEM_EQUATION_LENGTH];

    public Model(double factorInit) {
        for (int i = 0; i < factors.length; i++) {
            factors[i] = i * i * factorInit;
        }
        for (int i = 0; i < state.length; i++) {
            state[i] = i * i * 123.0;
        }
    }

    public void calculate(double timeStep) {
        state[0] = state[LAST_INDEX] + state[0] * factors[0] * timeStep;

        for (int i = LAST_INDEX; i > 0; i--) {
            // dstate / timeStep = state[i-1] * factor
            state[i] = state[i] + state[i - 1] * factors[i] * timeStep;
        }
    }

    public void setInput(double value) {
        state[INPUT_INDEX] = value;
    }

    public double getOutput() {
        return state[OUTPUT_INDEX];
    }
}
package jcscore;

public class ModelExchanger {
    private final Model[] models;
    private final int length;
    private final int lastIndex;

    public ModelExchanger(Model... models) {
        this.models = models;
        this.length = models.length;
        lastIndex = length - 1;
    }

    public void exchange() {
        models[0].setInput(models[lastIndex].getOutput());
        for (int i = 1; i < length; i++) {
            models[i].setInput(models[lastIndex].getOutput());
        }
    }
}
public class BenchmarkMain {
    public static void main(String[] args) throws Exception {
        org.openjdk.jmh.Main.main(args);
    }
}
package jcscore;

import lombok.Data;
import lombok.Getter;
import org.junit.Assert;
import org.junit.Test;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Threads;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;
import java.util.function.Consumer;

public class ModelParallelizationTest
{
        @Benchmark
        @BenchmarkMode(Mode.AverageTime)
        public void testNotParallelization() {
                final double timeStep = 0.000001;
                ActionRunnable task1 = new ActionRunnable(0.0001, timeStep);
                ActionRunnable task2 = new ActionRunnable(0.0002, timeStep);
                ActionRunnable task3 = new ActionRunnable(0.0003, timeStep);
                ModelExchanger modelExchanger = new ModelExchanger(task1.getModel(), task2.getModel(), task3.getModel());

                for (int i = 0; i < 1000; i++) {
                        task1.run();
                        task2.run();
                        task3.run();

                        task1.clear();
                        task2.clear();
                        task3.clear();
                        modelExchanger.exchange();
                }

                Assert.assertEquals(3075.0049, task1.getModel().getOutput(), 0.001);
        }

        @Threads(5)
        @Benchmark
        @BenchmarkMode(Mode.AverageTime)
        public void testServiceExecutor() throws ExecutionException, InterruptedException {
                final double timeStep = 0.000001;
                ActionRunnable task1 = new ActionRunnable(0.0001, timeStep);
                ActionRunnable task2 = new ActionRunnable(0.0002, timeStep);
                ActionRunnable task3 = new ActionRunnable(0.0003, timeStep);
                ModelExchanger modelExchanger = new ModelExchanger(task1.getModel(), task2.getModel(), task3.getModel());

                ExecutorService executor = Executors.newFixedThreadPool(4);

                for (int i = 0; i < 1000; i++)
                {
                        Future<?> future1 = executor.submit(task1);
                        Future<?> future2 = executor.submit(task2);
                        Future<?> future3 = executor.submit(task3);

                        future1.get();
                        future2.get();
                        future3.get();

                        task1.clear();
                        task2.clear();
                        task3.clear();
                        modelExchanger.exchange();
                }

                executor.shutdown();

                Assert.assertEquals(3075.0049, task1.getModel().getOutput(), 0.001);
        }

        @Threads(5)
        @Benchmark
        @BenchmarkMode(Mode.AverageTime)
        public void testForkJoin() {
                final double timeStep = 0.000001;
                ForkJoinAction task1 = new ForkJoinAction(0.0001, timeStep);
                ForkJoinAction task2 = new ForkJoinAction(0.0002, timeStep);
                ForkJoinAction task3 = new ForkJoinAction(0.0003, timeStep);
                ForkJoinAction[] tasks = {task1, task2, task3};
                ModelExchanger modelExchanger = new ModelExchanger(task1.getModel(), task2.getModel(), task3.getModel());

                ForkJoinPool forkJoinPool = new ForkJoinPool();
                ForkJoinActions forkJoinActions = new ForkJoinActions(tasks);

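                for (int i = 0; i < 1000; i++)
                {
                        // A completed ForkJoinTask is not executed again until it is reset
                        forkJoinActions.reinitialize();
                        for (ForkJoinAction task : tasks)
                        {
                                task.reinitialize();
                        }
                        forkJoinPool.invoke(forkJoinActions);

                        task1.clear();
                        task2.clear();
                        task3.clear();
                        modelExchanger.exchange();
                }

                forkJoinPool.shutdown();

                Assert.assertEquals(3075.0049, task1.getModel().getOutput(), 0.001);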
        }

        @Threads(5)
        @Benchmark
        @BenchmarkMode(Mode.AverageTime)
        public void testStreamParallel()
        {
                final double timeStep = 0.000001;
                ActionRunnable task1 = new ActionRunnable(0.0001, timeStep);
                ActionRunnable task2 = new ActionRunnable(0.0002, timeStep);
                ActionRunnable task3 = new ActionRunnable(0.0003, timeStep);
                ModelExchanger modelExchanger = new ModelExchanger(task1.getModel(), task2.getModel(), task3.getModel());

                List<ActionRunnable> tasks = Arrays.asList(task1, task2, task3);
                Consumer<ActionRunnable> taskConsumer = ActionRunnable::run;

                for (int i = 0; i < 1000; i++)
                {
                        tasks.parallelStream().forEach(taskConsumer);

                        task1.clear();
                        task2.clear();
                        task3.clear();
                        modelExchanger.exchange();
                }

                Assert.assertEquals(3075.0049, task1.getModel().getOutput(), 0.001);
        }

        @Getter
        public static class ActionRunnable implements Runnable
        {
                private final Model model;
                private double timeStep;
                private boolean complete = false;

                public ActionRunnable(double factorInit, double timeStep)
                {
                        this.model = new Model(factorInit);
                        this.timeStep = timeStep;
                }

                public Model getModel() {
                        return model;
                }

                public void clear()
                {
                        complete = false;
                }

                @Override
                public void run()
                {
                        model.calculate(timeStep);
                        complete = true;
                }
        }

        @Getter
        public static class ForkJoinAction extends RecursiveTask<Void>
        {
                private final Model model;
                private double timeStep;
                private boolean complete = false;

                public ForkJoinAction(double factorInit, double timeStep)
                {
                        this.model = new Model(factorInit);
                        this.timeStep = timeStep;
                }

                public Model getModel() {
                        return model;
                }

                public void clear()
                {
                        complete = false;
                }

                @Override
                public Void compute()
                {
                        model.calculate(timeStep);
                        complete = true;
                        return null;
                }
        }

        public static class ForkJoinActions extends RecursiveTask<Void>
        {
                private final ForkJoinAction[] actions;

                public ForkJoinActions(ForkJoinAction[] actions)
                {
                        this.actions = actions;
                }

                @Override
                protected Void compute()
                {
                        final int length = actions.length;
                        for (int i = 0; i < length; i++)
                        {
                                actions[i].fork(); // start asynchronously
                        }
                        for (int i = 0; i < length; i++)
                        {
                                actions[i].join(); // wait for the task to complete
                        }
                        return null;
                }
        }
}
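testForkJoin reuses the same ForkJoinAction and ForkJoinActions instances in every iteration. The ForkJoinTask documentation describes reinitialize() as resetting the task's internal bookkeeping state, "allowing a subsequent fork"; a task that has already completed is not computed again when it is forked or invoked a second time. This small self-contained check (CountingAction and countRuns are illustrative names) counts how often compute() really runs when one task instance is invoked three times:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.atomic.AtomicInteger;

/** Shows that a completed ForkJoinTask is not re-executed unless reinitialize() is called. */
public class ReinitializeDemo {

    /** Counts how many times compute() actually runs. */
    static final class CountingAction extends RecursiveAction {
        final AtomicInteger runs = new AtomicInteger();

        @Override
        protected void compute() {
            runs.incrementAndGet();
        }
    }

    /** Invokes the same task instance three times, optionally resetting it before each use. */
    public static int countRuns(boolean reinitialize) {
        ForkJoinPool pool = new ForkJoinPool(2);
        try {
            CountingAction task = new CountingAction();
            for (int i = 0; i < 3; i++) {
                if (reinitialize) {
                    task.reinitialize(); // reset completion status so the task can run again
                }
                pool.invoke(task);
            }
            return task.runs.get();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("without reinitialize: " + countRuns(false)); // 1
        System.out.println("with reinitialize:    " + countRuns(true));  // 3
    }
}
```

If the benchmark's tasks are not reset, only the first of the 1000 iterations actually calls model.calculate(), so the measured time is scheduling overhead alone.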

Results:

Result "jcscore.ModelParallelizationTest.testNotParallelization":
  0,001 ±(99.9%) 0,001 s/op [Average]
  (min, avg, max) = (≈ 10⁻³, 0,001, 0,001), stdev = 0,001
  CI (99.9%): [≈ 10⁻³, 0,001] (assumes normal distribution)

Result "jcscore.ModelParallelizationTest.testServiceExecutor":
  0,015 ±(99.9%) 0,001 s/op [Average]
  (min, avg, max) = (0,012, 0,015, 0,019), stdev = 0,002
  CI (99.9%): [0,014, 0,017] (assumes normal distribution)
  
Result "jcscore.ModelParallelizationTest.testForkJoin":
  0,061 ±(99.9%) 0,016 s/op [Average]
  (min, avg, max) = (0,027, 0,061, 0,103), stdev = 0,022
  CI (99.9%): [0,044, 0,077] (assumes normal distribution)
  
Result "jcscore.ModelParallelizationTest.testStreamParallel":
  0,002 ±(99.9%) 0,001 s/op [Average]
  (min, avg, max) = (0,002, 0,002, 0,003), stdev = 0,001
  CI (99.9%): [0,002, 0,002] (assumes normal distribution)

Benchmark                                        Mode  Cnt  Score   Error  Units
ModelParallelizationTest.testNotParallelization  avgt   25  0,001 ± 0,001   s/op
ModelParallelizationTest.testServiceExecutor     avgt   25  0,015 ± 0,001   s/op
ModelParallelizationTest.testForkJoin            avgt   25  0,061 ± 0,016   s/op
ModelParallelizationTest.testStreamParallel      avgt   25  0,002 ± 0,001   s/op
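Converting the reported averages into time per iteration (one benchmark operation is 1000 iterations of three models) and comparing with the 100,000 iterations per second target gives the rough figures below. They are only indicative: the sequential score is at the resolution of the printout, the parallel variants ran with @Threads(5) while the sequential one did not, and the executor and fork/join variants create their pools inside the measured method.

```latex
\begin{aligned}
t_{\text{budget}}   &= \frac{1\ \text{s}}{100\,000\ \text{iterations}} = 10\ \mu\text{s per iteration}\\
t_{\text{seq}}      &\approx \frac{0.001\ \text{s}}{1000\ \text{iterations}} = 1\ \mu\text{s per iteration}\\
t_{\text{stream}}   &\approx \frac{0.002\ \text{s}}{1000\ \text{iterations}} = 2\ \mu\text{s per iteration}\\
t_{\text{executor}} &\approx \frac{0.015\ \text{s}}{1000\ \text{iterations}} = 15\ \mu\text{s per iteration}\\
t_{\text{forkjoin}} &\approx \frac{0.061\ \text{s}}{1000\ \text{iterations}} = 61\ \mu\text{s per iteration}
\end{aligned}
```

On these figures the whole iteration does roughly 1 µs of real work, so any scheme whose per-iteration hand-off between threads costs more than that cannot come out ahead of the sequential loop.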

The parallel version needs to be significantly faster than the sequential one. Any ideas on how to speed it up?

Answer:

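The best solution is to use parallelStream(); of the parallel variants in your benchmark it was also the fastest (0,002 s/op, versus 0,015 for the executor and 0,061 for fork/join).

Alternatively, if you are familiar with the fork/join pool concept, use ForkJoinPool directly, which is essentially the same thing: in the JDK, parallel streams run their work on the common ForkJoinPool.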
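A minimal sketch of the parallelStream() approach, assuming a hypothetical ToyModel in place of the question's Model; the exchange mirrors ModelExchanger as written (every model receives the last model's output). The method runs the same simulation sequentially or in parallel so the two results can be compared: each model only touches its own fields inside calculate(), and forEach returns only after all elements have been processed, so the exchange always sees a completed step.

```java
import java.util.Arrays;

/** Sketch: one simulation step = parallel calculate() over all models, then a sequential exchange. */
public class ParallelStreamStep {

    /** Hypothetical toy model standing in for the real Model class. */
    static final class ToyModel {
        private final double rate;
        private double state;
        private double input;

        ToyModel(double rate, double initial) { this.rate = rate; this.state = initial; }

        void calculate(double dt) { state += (input - rate * state) * dt; } // one explicit Euler step
        void setInput(double value) { input = value; }
        double getOutput() { return state; }
    }

    static ToyModel[] newModels() {
        return new ToyModel[] { new ToyModel(0.5, 1.0), new ToyModel(0.25, 2.0), new ToyModel(0.1, 3.0) };
    }

    /** Same rule as ModelExchanger.exchange(): every model gets the last model's output. */
    static void exchange(ToyModel[] m) {
        double out = m[m.length - 1].getOutput();
        for (ToyModel model : m) model.setInput(out);
    }

    public static double simulate(int steps, double dt, boolean parallel) {
        ToyModel[] m = newModels();
        for (int s = 0; s < steps; s++) {
            if (parallel) {
                // forEach is a terminal operation: it returns only after every model
                // has finished calculate(), so the exchange below sees the whole step.
                Arrays.stream(m).parallel().forEach(model -> model.calculate(dt));
            } else {
                for (ToyModel model : m) model.calculate(dt);
            }
            exchange(m);
        }
        return m[0].getOutput();
    }

    public static void main(String[] args) {
        System.out.println(simulate(1000, 1e-3, false));
        System.out.println(simulate(1000, 1e-3, true));
    }
}
```

Parallel streams reuse the common pool's long-lived worker threads, so nothing is created per iteration; the fork/join hand-off on every step remains, though, which is consistent with the benchmark above, where it still did not beat the sequential loop for such a cheap calculate().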
