Is CSVSequenceRecordReader creating compatible dataset for training LSTM network? - deeplearning4j

I want to train a simple LSTM network, but I get this exception:
java.lang.IllegalStateException: C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order and not a view. C (result) array: [Rank: 2,Offset: 0 Order: f Shape: [10,1], stride: [1,10]]
I'm training a simple NN with a single LSTM cell and a single output cell for regression.
I created a training dataset of just 10 samples with variable sequence lengths (from 5 to 10) in CSV files; each sample consists of just one value for the input and one value for the output.
I created a SequenceRecordReaderDataSetIterator from a CSVSequenceRecordReader.
When I train the network, the code throws the exception above.
I tried generating a random dataset by coding the dataset iterator directly with 'f'-order INDArrays, and that code runs without error.
So the problem seems to be the ordering of the arrays created by CSVSequenceRecordReader.
Has anyone else had this problem?
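For reference, a minimal sketch of the direct-INDArray variant that runs without error (the shape [miniBatchSize, numFeatures, timeSteps] and the sizes are placeholders, not taken from my actual data):
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

// build features/labels directly in 'f' (column-major) order instead of the default 'c'
INDArray features = Nd4j.create(new int[] { 10, 1, 10 }, 'f'); // zeros here; fill with data as needed
INDArray labels = Nd4j.create(new int[] { 10, 1, 10 }, 'f');
DataSet dataSet = new DataSet(features, labels);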
SingleFileTimeSeriesDataReader.java
package org.mmarini.lstmtest;
import java.io.IOException;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
/**
* Reads a time-series training dataset from numbered CSV sample files.
*/
public class SingleFileTimeSeriesDataReader {
private final int miniBatchSize;
private final int numPossibleLabels;
private final boolean regression;
private final String filePattern;
private final int maxFileIdx;
private final int minFileIdx;
private final int numInputs;
/**
*
* @param filePattern
* @param minFileIdx
* @param maxFileIdx
* @param numInputs
* @param numPossibleLabels
* @param miniBatchSize
* @param regression
*/
public SingleFileTimeSeriesDataReader(final String filePattern, final int minFileIdx, final int maxFileIdx,
final int numInputs, final int numPossibleLabels, final int miniBatchSize, final boolean regression) {
this.miniBatchSize = miniBatchSize;
this.numPossibleLabels = numPossibleLabels;
this.regression = regression;
this.filePattern = filePattern;
this.maxFileIdx = maxFileIdx;
this.minFileIdx = minFileIdx;
this.numInputs = numInputs;
}
/**
*
* @return
* @throws IOException
* @throws InterruptedException
*/
public DataSetIterator apply() throws IOException, InterruptedException {
final SequenceRecordReader reader = new CSVSequenceRecordReader(0, ",");
reader.initialize(new NumberedFileInputSplit(filePattern, minFileIdx, maxFileIdx));
final DataSetIterator iter = new SequenceRecordReaderDataSetIterator(reader, miniBatchSize, numPossibleLabels,
numInputs, regression);
return iter;
}
}
TestConfBuilder.java
/**
*
*/
package org.mmarini.lstmtest;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
/**
* @author mmarini
*
*/
public class TestConfBuilder {
private final int noInputUnits;
private final int noOutputUnits;
private final int noLstmUnits;
/**
*
* @param noInputUnits
* @param noOutputUnits
* @param noLstmUnits
*/
public TestConfBuilder(final int noInputUnits, final int noOutputUnits, final int noLstmUnits) {
super();
this.noInputUnits = noInputUnits;
this.noOutputUnits = noOutputUnits;
this.noLstmUnits = noLstmUnits;
}
/**
*
* @return
*/
public MultiLayerConfiguration build() {
final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
final LSTM lstmLayer = new LSTM.Builder().units(noLstmUnits).nIn(noInputUnits).activation(Activation.TANH)
.build();
final RnnOutputLayer outLayer = new RnnOutputLayer.Builder(LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
.activation(Activation.IDENTITY).nOut(noOutputUnits).nIn(noLstmUnits).build();
final MultiLayerConfiguration conf = builder.list(lstmLayer, outLayer).build();
return conf;
}
}
TestTrainingTest.java
package org.mmarini.lstmtest;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
class TestTrainingTest {
private static final int MINI_BATCH_SIZE = 10;
private static final int NUM_LABELS = 1;
private static final boolean REGRESSION = true;
private static final String SAMPLES_FILE = "src/test/resources/datatest/sample_%d.csv";
private static final int MIN_INPUTS_FILE_IDX = 0;
private static final int MAX_INPUTS_FILE_IDX = 9;
private static final int NUM_INPUTS_COLUMN = 1;
private static final int NUM_HIDDEN_UNITS = 1;
DataSetIterator createData() {
final double[][][] featuresAry = new double[][][] { { { 0.5, 0.2, 0.5 } }, { { 0.5, 1.0, 0.0 } } };
final double[] featuresData = ArrayUtil.flattenDoubleArray(featuresAry);
final int[] featuresShape = new int[] { 2, 1, 3 };
final INDArray features = Nd4j.create(featuresData, featuresShape, 'c');
final double[][][] labelsAry = new double[][][] { { { 1.0, -1.0, 1.0 }, { 1.0, -1.0, -1.0 } } };
final double[] labelsData = ArrayUtil.flattenDoubleArray(labelsAry);
final int[] labelsShape = new int[] { 2, 1, 3 };
final INDArray labels = Nd4j.create(labelsData, labelsShape, 'c');
final INDArrayDataSetIterator iter = new INDArrayDataSetIterator(
Arrays.asList(new Pair<INDArray, INDArray>(features, labels)), 2);
System.out.println(iter.inputColumns());
return iter;
}
private String file(String template) {
return new File(".", template).getAbsolutePath();
}
@Test
void testBuild() throws IOException, InterruptedException {
final SingleFileTimeSeriesDataReader reader = new SingleFileTimeSeriesDataReader(file(SAMPLES_FILE),
MIN_INPUTS_FILE_IDX, MAX_INPUTS_FILE_IDX, NUM_INPUTS_COLUMN, NUM_LABELS, MINI_BATCH_SIZE, REGRESSION);
final DataSetIterator data = reader.apply();
assertThat(data.inputColumns(), equalTo(NUM_INPUTS_COLUMN));
assertThat(data.totalOutcomes(), equalTo(NUM_LABELS));
final TestConfBuilder builder = new TestConfBuilder(NUM_INPUTS_COLUMN, NUM_LABELS, NUM_HIDDEN_UNITS);
final MultiLayerConfiguration conf = builder.build();
final MultiLayerNetwork net = new MultiLayerNetwork(conf);
assertNotNull(net);
net.init();
net.fit(data);
}
}
I expected no exception to be thrown, but I got the following:
java.lang.IllegalStateException: C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order and not a view. C (result) array: [Rank: 2,Offset: 0 Order: f Shape: [10,1], stride: [1,10]]
at org.nd4j.base.Preconditions.throwStateEx(Preconditions.java:641)
at org.nd4j.base.Preconditions.checkState(Preconditions.java:304)
at org.nd4j.linalg.factory.Nd4j.gemm(Nd4j.java:980)
at org.deeplearning4j.nn.layers.recurrent.LSTMHelpers.backpropGradientHelper(LSTMHelpers.java:696)
at org.deeplearning4j.nn.layers.recurrent.LSTM.backpropGradientHelper(LSTM.java:122)
at org.deeplearning4j.nn.layers.recurrent.LSTM.backpropGradient(LSTM.java:93)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.calcBackpropGradients(MultiLayerNetwork.java:1826)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2644)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2587)
at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:160)
at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:63)
at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fitHelper(MultiLayerNetwork.java:1602)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1521)
at org.mmarini.lstmtest.TestTrainingTest.testBuild(TestTrainingTest.java:77)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:532)
at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:115)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.lambda$invokeTestMethod$6(TestMethodTestDescriptor.java:171)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.invokeTestMethod(TestMethodTestDescriptor.java:167)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:114)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:59)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:108)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
at java.util.ArrayList.forEach(ArrayList.java:1257)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:38)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:112)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
at java.util.ArrayList.forEach(ArrayList.java:1257)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:38)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:112)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.submit(SameThreadHierarchicalTestExecutorService.java:32)
at org.junit.platform.engine.support.hierarchical.HierarchicalTestExecutor.execute(HierarchicalTestExecutor.java:57)
at org.junit.platform.engine.support.hierarchical.HierarchicalTestEngine.execute(HierarchicalTestEngine.java:51)
at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:220)
at org.junit.platform.launcher.core.DefaultLauncher.lambda$execute$6(DefaultLauncher.java:188)
at org.junit.platform.launcher.core.DefaultLauncher.withInterceptedStreams(DefaultLauncher.java:202)
at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:181)
at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:128)
at org.eclipse.jdt.internal.junit5.runner.JUnit5TestReference.run(JUnit5TestReference.java:89)
at org.eclipse.jdt.internal.junit.runner.TestExecution.run(TestExecution.java:41)
at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:541)
at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:763)
at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.run(RemoteTestRunner.java:463)
at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.main(RemoteTestRunner.java:209)

Please see the DL4J Gitter community: https://gitter.im/deeplearning4j/deeplearning4j

Related

proper video streaming with rxjava

To handle a video stream from a webcam (delivered by OpenCV) I am considering using RxJava.
I am hoping to achieve the following:
being able to control the frames per second to be delivered (see the sketch after the code below)
being able to handle different inputs - e.g. a live webcam, a video or even a still picture
being able to switch to picture-by-picture handling under the control of a GUI
I have been experimenting a bit with RxJava but I am confused about the debounce, throttleFirst and async operators.
Examples like https://stackoverflow.com/a/48723331/1497139 show some code but I am missing a more detailed explanation.
Where could I find a decent example for video processing, or something similar along the needs mentioned above?
The code below does some non-async logic at this time - please let me know if I could build on it:
ImageFetcher
import org.opencv.core.Mat;
import org.opencv.videoio.VideoCapture;
import rx.Observable;
import rx.functions.Action1;
import rx.functions.Func0;
import rx.functions.Func1;
/**
* fetcher for Images
*
*/
public class ImageFetcher {
// OpenCV video capture
private VideoCapture capture = new VideoCapture();
private String source;
protected int frameIndex;
public int getFrameIndex() {
return frameIndex;
}
/**
* fetch from the given source
*
* @param source
* - the source to fetch from
*/
public ImageFetcher(String source) {
this.source = source;
}
/**
* try opening my source
*
* @return true if successful
*/
public boolean open() {
boolean ret = this.capture.open(source);
frameIndex=0;
return ret;
}
/**
* fetch an image Matrix
*
* @return - the image fetched
*/
public Mat fetch() {
if (!this.capture.isOpened()) {
boolean ret = this.open();
if (!ret) {
String msg = String.format(
"Trying to fetch image from unopened VideoCapture and open %s failed",
source);
throw new IllegalStateException(msg);
}
}
final Mat frame = new Mat();
this.capture.read(frame);
frameIndex++;
return !frame.empty() ? frame : null;
}
@Override
protected void finalize() throws Throwable {
super.finalize();
}
/**
* convert me to an observable
* @return a Mat emitting Observable
*/
public Observable<Mat> toObservable() {
// Resource creation.
Func0<VideoCapture> resourceFactory = () -> {
VideoCapture capture = new VideoCapture();
capture.open(source);
return capture;
};
// Convert to observable.
Func1<VideoCapture, Observable<Mat>> observableFactory = capture -> Observable
.<Mat> create(subscriber -> {
boolean hasNext = true;
while (hasNext) {
final Mat frame = this.fetch();
hasNext = frame!=null && frame.rows()>0 && frame.cols()>0;
if (hasNext) {
String msg = String.format("->%6d:%4dx%d", frameIndex, frame.cols(), frame.rows());
System.out.println(msg);
subscriber.onNext(frame);
}
}
subscriber.onCompleted();
});
// Disposal function.
Action1<VideoCapture> dispose = VideoCapture::release;
return Observable.using(resourceFactory, observableFactory, dispose);
}
}
ImageSubscriber
import org.opencv.core.Mat;
import rx.Subscriber;
public class ImageSubscriber extends Subscriber<Mat> {
public Throwable error;
public int cols = 0;
public int rows=0;
public int frameIndex=0;
public boolean completed = false;
public boolean debug = false;
@Override
public void onCompleted() {
completed = true;
}
@Override
public void onError(Throwable e) {
error = e;
}
@Override
public void onNext(Mat mat) {
cols = mat.cols();
rows = mat.rows();
frameIndex++;
if (cols==0 || rows==0)
System.err.println("invalid frame "+frameIndex);
if (debug) {
String msg = String.format("%6d:%4dx%d", frameIndex, cols, rows);
System.out.println(msg);
}
}
}
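As a starting point for the frames-per-second requirement, a hedged sketch (untested; RxJava 1.x, building on the two classes above, with "video.mp4" as a placeholder source): sample keeps at most the most recent frame per interval, which approximates an fps cap.
import java.util.concurrent.TimeUnit;
import org.opencv.core.Mat;
import rx.Observable;
import rx.schedulers.Schedulers;

public class FpsLimitDemo {
public static void main(String[] args) {
ImageFetcher fetcher = new ImageFetcher("video.mp4"); // placeholder source
Observable<Mat> frames = fetcher.toObservable()
// run the blocking fetch loop off the caller's thread
.subscribeOn(Schedulers.io())
// keep the most recent frame per 100 ms (~10 fps); earlier frames in the interval are dropped
.sample(100, TimeUnit.MILLISECONDS);
frames.subscribe(new ImageSubscriber());
// (in a real app, keep the JVM alive while frames arrive)
}
}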

How to Batch By N Elements in Streaming Pipeline With Small Bundles?

I've implemented batching by N elements as described in this answer:
Can datastore input in google dataflow pipeline be processed in a batch of N entries at a time?
package com.example.dataflow.transform;
import com.example.dataflow.event.ClickEvent;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.joda.time.Instant;
import java.util.ArrayList;
import java.util.List;
public class ClickToClicksPack extends DoFn<ClickEvent, List<ClickEvent>> {
public static final int BATCH_SIZE = 10;
private List<ClickEvent> accumulator;
@StartBundle
public void startBundle() {
accumulator = new ArrayList<>(BATCH_SIZE);
}
@ProcessElement
public void processElement(ProcessContext c) {
ClickEvent clickEvent = c.element();
accumulator.add(clickEvent);
if (accumulator.size() >= BATCH_SIZE) {
c.output(accumulator);
accumulator = new ArrayList<>(BATCH_SIZE);
}
}
@FinishBundle
public void finishBundle(FinishBundleContext c) {
if (accumulator.size() > 0) {
ClickEvent clickEvent = accumulator.get(0);
long time = clickEvent.getClickTimestamp().getTime();
c.output(accumulator, new Instant(time), GlobalWindow.INSTANCE);
}
}
}
But when I run the pipeline in streaming mode there are a lot of batches with just 1 or 2 elements. As I understand it, this is because of the small bundle sizes. After running for a day, the average number of elements per batch is roughly 4. I really need it to be closer to 10 for better performance of the next steps.
Is there a way to control bundle size?
Or should I use the "GroupIntoBatches" transform for this purpose? In that case it's not clear to me what should be selected as the key.
UPDATE:
Is it a good idea to use the Java thread ID or the VM hostname as a key for the "GroupIntoBatches" transform?
I've ended up writing a composite transform with "GroupIntoBatches" inside.
The following answer contains recommendations regarding key selection:
https://stackoverflow.com/a/44956702/4888849
In my current implementation I'm using random keys to achieve parallelism, and I'm windowing events in order to emit results regularly even if there are fewer than BATCH_SIZE events for a given key. A usage sketch follows the transform below.
package com.example.dataflow.transform;
import com.example.dataflow.event.ClickEvent;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Duration;
import java.util.Random;
/**
* Batch clicks into packs of BATCH_SIZE size
*/
public class ClickToClicksPack extends PTransform<PCollection<ClickEvent>, PCollection<Iterable<ClickEvent>>> {
public static final int BATCH_SIZE = 10;
// Define window duration.
// After the window ends, elements are emitted even if there are fewer than BATCH_SIZE of them
public static final int WINDOW_DURATION_SECONDS = 1;
private static final int DEFAULT_SHARDS_NUMBER = 20;
// Determine possible parallelism level
private int shardsNumber = DEFAULT_SHARDS_NUMBER;
public ClickToClicksPack() {
super();
}
public ClickToClicksPack(int shardsNumber) {
super();
this.shardsNumber = shardsNumber;
}
@Override
public PCollection<Iterable<ClickEvent>> expand(PCollection<ClickEvent> input) {
return input
// assign keys, as "GroupIntoBatches" works only with key-value pairs
.apply(ParDo.of(new AssignRandomKeys(shardsNumber)))
.apply(Window.into(FixedWindows.of(Duration.standardSeconds(WINDOW_DURATION_SECONDS))))
.apply(GroupIntoBatches.ofSize(BATCH_SIZE))
.apply(ParDo.of(new ExtractValues()));
}
/**
* Assigns to clicks random integer between zero and shardsNumber
*/
private static class AssignRandomKeys extends DoFn<ClickEvent, KV<Integer, ClickEvent>> {
private int shardsNumber;
private Random random;
AssignRandomKeys(int shardsNumber) {
super();
this.shardsNumber = shardsNumber;
}
@Setup
public void setup() {
random = new Random();
}
@ProcessElement
public void processElement(ProcessContext c) {
ClickEvent clickEvent = c.element();
KV<Integer, ClickEvent> kv = KV.of(random.nextInt(shardsNumber), clickEvent);
c.output(kv);
}
}
/**
* Extract values from KV
*/
private static class ExtractValues extends DoFn<KV<Integer, Iterable<ClickEvent>>, Iterable<ClickEvent>> {
@ProcessElement
public void processElement(ProcessContext c) {
KV<Integer, Iterable<ClickEvent>> kv = c.element();
c.output(kv.getValue());
}
}
}
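A hedged usage sketch (the surrounding pipeline and the upstream ClickEvent source are assumed, not part of the original code):
import org.apache.beam.sdk.values.PCollection;

// hypothetical helper showing where the transform plugs in;
// "clicks" is whatever upstream step produces the ClickEvent stream
static PCollection<Iterable<ClickEvent>> batchClicks(PCollection<ClickEvent> clicks) {
return clicks.apply("BatchClicks", new ClickToClicksPack(40)); // e.g. 40 shards
}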

How to extract date from tweets with lucene?

I have some tweets that I have already indexed with Lucene ("TweetIndexer"); each tweet contains an ID, USER, TEXT, and DATE.
I want to retrieve only the date with another class, "TweetSearcher". How should I proceed?
Example of a tweet:
"0","1467811372","Mon Apr 06 22:20:00 PDT 2009","NO_QUERY","joy_wolf","#Kwesidei not the whole crew ".
This is my class TweetIndexer:
import org.apache.lucene.analysis.core.KeywordAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Version;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
public class TweetIndexer {
protected static final String COMMA = "\",\"";
protected static final String POLARITY = "polarity";
protected static final String ID = "id";
protected static final String DATE = "date";
protected static final String QUERY = "query";
protected static final String USER = "user";
protected static final String TEXT = "text";
public static void main(String[] args) throws Exception {
try {
String indexDir = "D:\\tweet\\index";
String dataFile = "D:\\tweet\\collection\\tweets.csv";
TweetIndexer tweetIndexer = new TweetIndexer();
long start = System.currentTimeMillis();
int count = tweetIndexer.index(new File(indexDir), new File(dataFile));
System.out.print(String.format("Indexed %d documents in %d seconds", count, (System.currentTimeMillis() - start) / 1000));
}
catch (Exception e) {
System.out.println("Usage: java TweetIndexer <index directory> <csv data file>");
}
}
private int index(File indexDir, File dataFile) throws Exception {
IndexWriter indexWriter = new IndexWriter(
FSDirectory.open(indexDir),
new IndexWriterConfig(Version.LUCENE_44, new KeywordAnalyzer()));
int count = indexFile(indexWriter, dataFile);
indexWriter.close();
return count;
}
private int indexFile(IndexWriter indexWriter, File dataFile) throws IOException {
FieldType fieldType = new FieldType();
fieldType.setStored(true);
fieldType.setIndexed(true);
BufferedReader bufferedReader = new BufferedReader(new FileReader(dataFile));
String line = "";
int count = 0;
while ((line = bufferedReader.readLine()) != null) {
// Hack to ignore commas within elements in csv (so we can split on "," rather than just ,)
line = line.substring(1, line.length() - 1);
String[] tweetInfo = line.split(COMMA);
Document document = new Document();
document.add(new Field(POLARITY, tweetInfo[0], fieldType));
document.add(new Field(ID, tweetInfo[1], fieldType));
document.add(new Field(DATE, tweetInfo[2], fieldType));
document.add(new Field(QUERY, tweetInfo[3], fieldType));
document.add(new StringField(USER, tweetInfo[4], Field.Store.YES));
document.add(new StringField(TEXT, tweetInfo[5], Field.Store.YES));
indexWriter.addDocument(document);
count++;
}
return count;
}
}
And this is the (shortened) Java code for my class TweetSearcher:
public class TweetSearcher {
public static void main(String[] args) throws Exception {
try {
String indexDir = "D:\\tweet\\index";
int numHits = Integer.parseInt("3");
TweetSearcher tweetSearcher = new TweetSearcher();
tweetSearcher.dateSearch(new File(indexDir), numHits);
} catch (Exception e) {
e.printStackTrace();
}
}
private void dateSearch(File indexDir, int numHits) throws Exception {
System.out.println("Find dates:");
Directory directory = FSDirectory.open(indexDir);
DirectoryReader directoryReader = DirectoryReader.open(directory);
IndexSearcher indexSearcher = new IndexSearcher(directoryReader);
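A possible continuation of dateSearch from this point, as a sketch (assuming the stored "date" field written by TweetIndexer above and a match-all query; needs the org.apache.lucene.search.MatchAllDocsQuery, TopDocs and ScoreDoc imports):
TopDocs topDocs = indexSearcher.search(new MatchAllDocsQuery(), numHits);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
Document document = indexSearcher.doc(scoreDoc.doc);
System.out.println(document.get("date")); // "date" is the field name used in TweetIndexer
}
directoryReader.close();
}
}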

The best classifier model for my data set in Spark

I am a newbie in Spark and ML, and I have a task that should be implemented with the Apache Spark API.
Some sample rows of my data are:
298,217756,468,0,363,0,0,14,0,11,0,0,894,cluster3
299,219413,25,1364,261,15,0,1,11,5,1,0,1760.5,cluster5
300,223153,1650,8673,2215,282,0,43,120,37,7,0,12853,cluster1
and I need to train a classifier whose model will predict the cluster of any arbitrary incoming row. For example, the model should predict the '?' in the following row:
318,240747,875,0,0,0,0,8,0,0,0,0,875,?
So what Spark data types, classifier, and so on should I use? How should I predict the '?'?
Any help is appreciated!
OK, I solved the issue :-) and am just posting the answer for other interested users.
The sample data is
60,236,178,0,0,4,15,16,0,0,575.00,5
1500,0,0,0,0,5,0,0,0,0,1500.00,5
50,2072,248,0,0,1,56,7,0,0,2658.50,5
package spark;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.feature.HashingTF;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;
import java.util.Arrays; // replaces the unusual scala.actors.threadpool.Arrays import
import java.text.DecimalFormat;
/**
*/
public class NaiveBayesTest {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("NaiveBayes Example").set("spark.driver.allowMultipleContexts", "true").set("hadoop.version","hadoop-2.4");
conf.setMaster("local[*]");
JavaSparkContext sc = new JavaSparkContext(conf);
String path = "resources/clustering-Result-without-index-id.csv";
JavaRDD<String> data = sc.textFile(path);
final HashingTF tf = new HashingTF(10000);
// Split the initial RDD into two... [~90% training data, ~10% testing data].
JavaRDD<LabeledPoint> mainData = data.map(
new Function<String , LabeledPoint>() {
@Override
public LabeledPoint call( String line) throws Exception {
String[] parts = line.split(",");
Double[] v = new Double[parts.length - 1];
for (int i = 0; i < parts.length - 1 ; i++){
v[i] = Double.parseDouble(parts[i]);
}
return new LabeledPoint(Double.parseDouble(parts[parts.length-1]),tf.transform(Arrays.asList(v)));
}
});
JavaRDD<LabeledPoint> training = mainData.sample(false, 0.9, 111L);
training.cache();
JavaRDD<LabeledPoint> test = mainData.subtract(training);
test.cache();
final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 23.0);
JavaPairRDD<Double, Double> predictionAndLabel =
test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
double cluster = model.predict(p.features());
String b = (cluster == p.label()) ? " ------> correct." : "";
System.out.println("predicted : "+cluster+ " , actual : " + p.label() + b);
return new Tuple2<Double, Double>(cluster, p.label());
}
});
double accuracy = predictionAndLabel.filter(
new Function<Tuple2<Double, Double>, Boolean>() {
@Override
public Boolean call(Tuple2<Double, Double> pl) {
return pl._1().equals(pl._2());
}
}).count() / (double) test.count();
System.out.println("accuracy is " + new DecimalFormat("#.000").format(accuracy * 100) + "%");
LabeledPoint point = new LabeledPoint(3,tf.transform(Arrays.asList(new String[]{"0,825,0,0,0,0,1,0,0,0,2180"})));
double d = model.predict(point.features());
System.out.println("predicted : "+d+ " , actual : " + point.label());
model.save(sc.sc(), "myModelPath");
NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath");
sameModel.labels();
}
}

win:length(2) is fired after first event

I made a very simple test GUI based on this brilliant article about getting started with Esper.
What surprises me is that this query fires after the very first tick event is sent, if the price is above 6.
select * from StockTick(symbol='AAPL').win:length(2) having avg(price) > 6.0
As far as I understand, win:length(2) needs TWO ticks before an event is fired, or am I wrong?
Here is an SSCCE for this question; just press the "Create Tick Event" button and you will see the StockTick event being fired at once.
It needs the following jars, which come bundled with Esper:
esper\lib\antlr-runtime-3.2.jar
esper\lib\cglib-nodep-2.2.jar
esper\lib\commons-logging-1.1.1.jar
esper\lib\esper_3rdparties.license
esper\lib\log4j-1.2.16.jar
esper-4.11.0.jar
import javax.swing.JFrame;
import javax.swing.JSplitPane;
import javax.swing.SwingUtilities;
import java.awt.BorderLayout;
import java.awt.Dimension;
import javax.swing.JButton;
import javax.swing.JScrollPane;
import java.awt.TextArea;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Date;
import java.util.Random;
import javax.swing.JRadioButton;
import javax.swing.JPanel;
import java.awt.GridLayout;
import javax.swing.JTextArea;
import javax.swing.ScrollPaneConstants;
import com.espertech.esper.client.Configuration;
import com.espertech.esper.client.EPAdministrator;
import com.espertech.esper.client.EPRuntime;
import com.espertech.esper.client.EPServiceProvider;
import com.espertech.esper.client.EPServiceProviderManager;
import com.espertech.esper.client.EPStatement;
import com.espertech.esper.client.EventBean;
import com.espertech.esper.client.UpdateListener;
import javax.swing.JTextField;
public class Tester extends JFrame {
/**
*
*/
private static final long serialVersionUID = 1L;
JButton createRandomValueEventButton;
private JPanel panel;
private JPanel southPanel;
private JPanel centerPanel;
private static JTextArea centerTextArea;
private static JTextArea southTextArea;
private static Random generator = new Random();
private EPRuntime cepRT;
private JSplitPane textSplitPane;
private JButton btnNewButton;
private static JTextField priceTextField;
public Tester() {
getContentPane().setLayout(new BorderLayout(0, 0));
JSplitPane splitPane = new JSplitPane();
createRandomValueEventButton = new JButton("Create Tick Event With Random Price");
splitPane.setLeftComponent(createRandomValueEventButton);
createRandomValueEventButton.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
createTickWithRandomPrice();
}
});
panel = new JPanel();
splitPane.setRightComponent(panel);
panel.setLayout(new GridLayout(1, 0, 0, 0));
btnNewButton = new JButton("Create Tick Event");
panel.add(btnNewButton);
btnNewButton.addActionListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
createTick();
}
});
priceTextField = new JTextField();
priceTextField.setText(new Integer(10).toString());
panel.add(priceTextField);
priceTextField.setColumns(4);
getContentPane().add(splitPane, BorderLayout.NORTH);
textSplitPane = new JSplitPane();
textSplitPane.setOrientation(JSplitPane.VERTICAL_SPLIT);
getContentPane().add(textSplitPane, BorderLayout.CENTER);
centerPanel = new JPanel();
centerPanel.setLayout(new BorderLayout(0, 0));
JScrollPane centerTextScrollPane = new JScrollPane();
centerTextScrollPane.setVerticalScrollBarPolicy(ScrollPaneConstants.VERTICAL_SCROLLBAR_ALWAYS);
centerTextArea = new JTextArea();
centerTextArea.setRows(12);
centerTextScrollPane.setViewportView(centerTextArea);
southPanel = new JPanel();
southPanel.setLayout(new BorderLayout(0, 0));
JScrollPane southTextScrollPane = new JScrollPane();
southTextScrollPane.setVerticalScrollBarPolicy(ScrollPaneConstants.VERTICAL_SCROLLBAR_ALWAYS);
southTextArea = new JTextArea();
southTextArea.setRows(5);
southTextScrollPane.setViewportView(southTextArea);
textSplitPane.setRightComponent(southTextScrollPane);
textSplitPane.setLeftComponent(centerTextScrollPane);
setupCEP();
}
public static void GenerateRandomTick(EPRuntime cepRT) {
double price = (double) generator.nextInt(10);
long timeStamp = System.currentTimeMillis();
String symbol = "AAPL";
Tick tick = new Tick(symbol, price, timeStamp);
System.out.println("Sending tick:" + tick);
centerTextArea.append(new Date().toString()+" Sending tick:" + tick+"\n");
cepRT.sendEvent(tick);
}
public static void GenerateTick(EPRuntime cepRT) {
double price = Double.parseDouble(priceTextField.getText());
long timeStamp = System.currentTimeMillis();
String symbol = "AAPL";
Tick tick = new Tick(symbol, price, timeStamp);
System.out.println("Sending tick:" + tick);
centerTextArea.append(new Date().toString()+" Sending tick: " + tick+"\n");
cepRT.sendEvent(tick);
}
public static void main(String[] args){
Tester tester = new Tester();
tester.setSize(new Dimension(570,500));
tester.setVisible(true);
}
private void createTickWithRandomPrice(){
SwingUtilities.invokeLater(new Runnable() {
public void run() {
GenerateRandomTick(getEPRuntime());
}
});
}
private void createTick(){
SwingUtilities.invokeLater(new Runnable() {
public void run() {
GenerateTick(getEPRuntime());
}
});
}
private void setupCEP(){
Configuration cepConfig = new Configuration();
cepConfig.addEventType("StockTick", Tick.class.getName());
EPServiceProvider cep = EPServiceProviderManager.getProvider("myCEPEngine", cepConfig);
cepRT = cep.getEPRuntime();
EPAdministrator cepAdm = cep.getEPAdministrator();
EPStatement cepStatement = cepAdm.createEPL(
"select * from " +
"StockTick(symbol='AAPL').win:length(2) " +
"having avg(price) > 6.0");
cepStatement.addListener(new CEPListener());
//System.out.println("cepStatement.getText(): "+cepStatement.getText());
}
private EPRuntime getEPRuntime(){
return cepRT;
}
public static class Tick {
String symbol;
Double price;
Date timeStamp;
public Tick(String s, double p, long t) {
symbol = s;
price = p;
timeStamp = new Date(t);
}
public double getPrice() {return price;}
public String getSymbol() {return symbol;}
public Date getTimeStamp() {return timeStamp;}
@Override
public String toString() {
return symbol+" Price: " + price.toString();
}
}
public static class CEPListener implements UpdateListener {
@Override
public void update(EventBean[] newData, EventBean[] oldData) {
// minimal listener body (reconstructed): log arriving events to the lower text area
if (newData != null && newData.length > 0) {
southTextArea.append(new Date().toString() + " Event received: " + newData[0].getUnderlying() + "\n");
}
}
}
}
Actually, aggregation and conditions are independent of how many events are in the data window. There are functions you can use to check whether a data window is "filled": "leaving", "count" or "prevcount", for example.
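For example, a sketch (untested) that keeps win:length(2) but only fires once the window actually holds two ticks:
select * from StockTick(symbol='AAPL').win:length(2) having avg(price) > 6.0 and count(*) >= 2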
For anyone interested, changing the query to this solved the problem:
select * from StockTick(symbol='AAPL').win:length_batch(2) having avg(price) > 6.0 and count(*) >= 2
Now an event will be triggered for consecutive ticks whose average price is higher than 6, in batches of two.
