package com.intel.daal.examples.kmeans;
import com.intel.daal.algorithms.kmeans.*;
import com.intel.daal.algorithms.kmeans.init.*;
import com.intel.daal.data_management.data.NumericTable;
import com.intel.daal.data_management.data_source.DataSource;
import com.intel.daal.data_management.data_source.FileDataSource;
import com.intel.daal.examples.utils.Service;
import com.intel.daal.services.DaalContext;
/**
 * Distributed-processing example for K-Means clustering.
 *
 * <p>The example simulates {@code nBlocks} nodes inside a single process; node {@code i} owns
 * the CSV file {@code datasetFileNames[i]}. The program runs these phases in order:
 * <ol>
 *   <li>Load every block once.</li>
 *   <li>Choose initial centroids with the distributed {@code randomDense} initialization.</li>
 *   <li>Run {@code nIterations} distributed K-Means steps. In each step, local step 1 runs on
 *       every node and the partial results are merged by a step-2 master.</li>
 *   <li>Compute the final cluster assignments on every node.</li>
 *   <li>Print the first node's assignments, the centroids and the goal function.</li>
 * </ol>
 */
class KMeansDenseDistributed {
    /** One input file per simulated node; index {@code i} is the data block of node {@code i}. */
    private static final String[] datasetFileNames = {
        "../data/distributed/kmeans_dense_1.csv", "../data/distributed/kmeans_dense_2.csv",
        "../data/distributed/kmeans_dense_3.csv", "../data/distributed/kmeans_dense_4.csv"};

    private static final int nClusters = 20;
    private static final int nBlocks = 4;
    private static final int nIterations = 5;

    /**
     * Rows per data block. The initialization step derives the total row count and each block's
     * starting row from this value, so every input file is assumed to hold exactly this many rows.
     */
    private static final int nVectorsInBlock = 2500;

    /** Context that owns every DAAL object created by this example; disposed at the end of main. */
    private static final DaalContext context = new DaalContext();

    public static void main(String[] args) throws java.io.FileNotFoundException, java.io.IOException {
        try {
            // Parse each block once. All phases share these tables instead of re-reading
            // the CSV files on every iteration.
            NumericTable[] data = loadBlocks();

            NumericTable centroids = initCentroids(data);

            // Stays null only if nIterations == 0 (no master step ever ran).
            NumericTable goalFunction = null;
            for (int it = 0; it < nIterations; it++) {
                Result result = runIteration(data, centroids);
                centroids = result.get(ResultId.centroids);
                goalFunction = result.get(ResultId.goalFunction);
            }

            NumericTable[] assignments = computeAssignments(data, centroids);

            Service.printNumericTable("First 10 cluster assignments from 1st node:", assignments[0], 10);
            Service.printNumericTable("First 10 dimensions of centroids:", centroids, nClusters, 10);
            if (goalFunction != null) {
                Service.printNumericTable("Goal function value:", goalFunction);
            }
        } finally {
            // Release all context-owned objects even if loading or computation fails.
            context.dispose();
        }
    }

    /** Loads each node's CSV file into its own numeric table, indexed by node number. */
    private static NumericTable[] loadBlocks() throws java.io.IOException {
        NumericTable[] blocks = new NumericTable[nBlocks];
        for (int node = 0; node < nBlocks; node++) {
            FileDataSource dataSource = new FileDataSource(context, datasetFileNames[node],
                    DataSource.DictionaryCreationFlag.DoDictionaryFromContext,
                    DataSource.NumericTableAllocationFlag.DoAllocateNumericTable);
            dataSource.loadDataBlock();
            blocks[node] = dataSource.getNumericTable();
        }
        return blocks;
    }

    /**
     * Chooses initial centroids with the distributed {@code randomDense} initialization.
     * Every node runs init step 1 on its own block, and the init master merges the partial results.
     *
     * @param data one table per node, indexed by node number
     * @return the initial centroids table
     */
    private static NumericTable initCentroids(NumericTable[] data) {
        InitDistributedStep2Master initMaster = new InitDistributedStep2Master(context, Double.class,
                InitMethod.randomDense, nClusters);
        for (int node = 0; node < nBlocks; node++) {
            // Passed: the row count of the whole data set, and the row at which this block starts.
            InitDistributedStep1Local initLocal = new InitDistributedStep1Local(context, Double.class,
                    InitMethod.randomDense, nClusters, nBlocks * nVectorsInBlock, node * nVectorsInBlock);
            initLocal.input.set(InitInputId.data, data[node]);
            InitPartialResult initPres = initLocal.compute();
            initMaster.input.add(InitDistributedStep2MasterInputId.partialResults, initPres);
        }
        initMaster.compute();
        InitResult initResult = initMaster.finalizeCompute();
        return initResult.get(InitResultId.centroids);
    }

    /**
     * Runs one distributed K-Means step against the current centroids.
     * Each node computes local partial results, and a step-2 master merges them.
     *
     * @param data one table per node, indexed by node number
     * @param centroids centroids produced by the previous step (or by initialization)
     * @return the master's result, holding the updated centroids and the goal function
     */
    private static Result runIteration(NumericTable[] data, NumericTable centroids) {
        // A fresh master per iteration guarantees that only this iteration's partial results
        // are merged. It does not rely on the master discarding earlier inputs after a compute.
        DistributedStep2Master masterAlgorithm = new DistributedStep2Master(context, Double.class,
                Method.defaultDense, nClusters);
        for (int node = 0; node < nBlocks; node++) {
            DistributedStep1Local algorithm = new DistributedStep1Local(context, Double.class,
                    Method.defaultDense, nClusters);
            algorithm.input.set(InputId.data, data[node]);
            algorithm.input.set(InputId.inputCentroids, centroids);
            PartialResult pres = algorithm.compute();
            masterAlgorithm.input.add(DistributedStep2MasterInputId.partialResults, pres);
        }
        masterAlgorithm.compute();
        return masterAlgorithm.finalizeCompute();
    }

    /**
     * Computes the final cluster assignments on every node against the given centroids.
     *
     * @param data one table per node, indexed by node number
     * @param centroids the final centroids
     * @return one assignments table per node, indexed by node number
     */
    private static NumericTable[] computeAssignments(NumericTable[] data, NumericTable centroids) {
        NumericTable[] assignments = new NumericTable[nBlocks];
        for (int node = 0; node < nBlocks; node++) {
            DistributedStep1Local algorithm = new DistributedStep1Local(context, Double.class,
                    Method.defaultDense, nClusters);
            algorithm.parameter.setAssignFlag(true);
            algorithm.input.set(InputId.data, data[node]);
            algorithm.input.set(InputId.inputCentroids, centroids);
            algorithm.compute();
            Result result = algorithm.finalizeCompute();
            assignments[node] = result.get(ResultId.assignments);
        }
        return assignments;
    }
}