Neural networks on Android: Google ML Kit and beyond

So, you have developed and trained a neural network for some task (for example, recognizing objects through the camera) and want to embed it in your Android application? Then read on!

To begin with, note that the tools covered here (Firebase ML Kit and the TensorFlow Lite interpreter) only work with networks in the TensorFlow Lite format, so the original network needs some preparation. Suppose you already have a network trained with Keras or TensorFlow. First, save it in the .pb format.

Let's start with the case where you write in plain TensorFlow, which is a little simpler.

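import os
import tensorflow as tf  # TensorFlow 1.x API
saver = tf.train.Saver()
# Graph structure only (no weights): binary net.pb and human-readable net.pbtxt
tf.train.write_graph(session.graph_def, path_to_folder, "net.pb", as_text=False)
tf.train.write_graph(session.graph_def, path_to_folder, "net.pbtxt", as_text=True)
# The weights are written to checkpoint files
saver.save(session, os.path.join(path_to_folder, "model.ckpt"))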

If you use Keras, get the session object at the beginning of the script where you train the network, keep a reference to it, and pass it to the set_session function. This is the session you then use in the saving code above:

import keras.backend as K
session = K.get_session()
K.set_session(session)

Great, the network is saved; now it has to be converted to the tflite format. This takes two small scripts: the first one “freezes” the network, and the second one converts it into the required format. Freezing is needed because TensorFlow does not store the layer weights in the saved pb file; it keeps them in separate checkpoint files. For the subsequent conversion to tflite, all information about the neural network must be in a single file, so freezing merges the checkpoint weights into the graph as constants.

freeze_graph --input_binary=false --input_graph=net.pbtxt --output_node_names=result/Softmax --output_graph=frozen_graph.pb --input_checkpoint=model.ckpt

Note that you need to know the name of the output tensor. In TensorFlow you set it yourself; in Keras, pass the name to the layer constructor:

model.add(Dense(10,activation="softmax",name="result"))

In that case, the name of the tensor usually looks like “result/Softmax”.

If that is not the case for your model, you can list all node names like this:

for n in session.graph.as_graph_def().node:
    print(n.name)

All that remains is to run the second script:

toco --graph_def_file=frozen_graph.pb --output_file=model.tflite --output_format=TFLITE --inference_type=FLOAT --input_arrays=input_input --output_arrays=result/Softmax --input_shapes=1,784
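Before moving on to Android, you may want a quick sanity check that the conversion really produced a TensorFlow Lite file. The sketch below is a helper of our own (the class name TfliteFileCheck is hypothetical, not part of any TensorFlow tool). It relies on the fact that TFLite models are FlatBuffers whose schema declares the file identifier "TFL3", which FlatBuffers store in bytes 4 to 7 of the file. It only looks at the header and does not validate the model itself.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical helper: checks whether a file starts like a TensorFlow Lite FlatBuffer (header only). */
class TfliteFileCheck {

    // The TFLite schema declares the FlatBuffers file identifier "TFL3".
    // FlatBuffers store it in bytes 4..7, right after the 4-byte root table offset.
    private static final byte[] TFLITE_IDENTIFIER = "TFL3".getBytes(StandardCharsets.US_ASCII);

    static boolean looksLikeTflite(Path modelPath) throws IOException {
        byte[] header = new byte[8];
        try (InputStream in = Files.newInputStream(modelPath)) {
            // Too short to hold a root offset plus an identifier: not a TFLite model.
            if (in.readNBytes(header, 0, header.length) < header.length) {
                return false;
            }
        }
        for (int i = 0; i < TFLITE_IDENTIFIER.length; i++) {
            if (header[4 + i] != TFLITE_IDENTIFIER[i]) {
                return false;
            }
        }
        return true;
    }
}
```

For example, TfliteFileCheck.looksLikeTflite(Paths.get("model.tflite")) should return true for the file produced by toco, and false for the frozen .pb graph.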

Hooray! You now have a TensorFlow Lite model in your folder, and integrating it into your Android application is straightforward. One way is the new Firebase ML Kit; there is another way, which we will cover a bit later. To use ML Kit, add the dependency to your app's build.gradle:

dependencies {
  // ...
  implementation 'com.google.firebase:firebase-ml-model-interpreter:16.2.0'
}

Next, decide whether the model will be hosted remotely (uploaded to the Firebase console) or shipped with the application.

Consider the first case: a model hosted remotely. First of all, do not forget to add the Internet permission to the manifest:

<uses-permission android:name="android.permission.INTERNET" />

Then set the download conditions and register the cloud model source:

// Specify the conditions required for downloading/updating the model
FirebaseModelDownloadConditions.Builder conditionsBuilder =
        new FirebaseModelDownloadConditions.Builder().requireWifi();
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
    conditionsBuilder = conditionsBuilder
            .requireCharging();
}
FirebaseModelDownloadConditions conditions = conditionsBuilder.build();
// Create a FirebaseCloudModelSource and set its name (it must match the name of the model uploaded to the Firebase console)
FirebaseCloudModelSource cloudSource = new FirebaseCloudModelSource.Builder("my_cloud_model")
        .enableModelUpdates(true)
        .setInitialDownloadConditions(conditions)
        .setUpdatesDownloadConditions(conditions)
        .build();
FirebaseModelManager.getInstance().registerCloudModelSource(cloudSource);

If the model is bundled with the application, do not forget to add the following to build.gradle so that the model file is not compressed:

android {
    // ...
    aaptOptions {
        noCompress "tflite"
    }
}

After that, register the local model, just as with the cloud one:

FirebaseLocalModelSource localSource = new FirebaseLocalModelSource.Builder("my_local_model")
        .setAssetFilePath("mymodel.tflite")
        .build();
FirebaseModelManager.getInstance().registerLocalModelSource(localSource);

The code above assumes that your model is in the assets folder. If it is not, instead of

        .setAssetFilePath("mymodel.tflite")

use

        .setFilePath(filePath)

Then we create new FirebaseModelOptions and FirebaseModelInterpreter objects.

FirebaseModelOptions options = new FirebaseModelOptions.Builder()
        .setCloudModelName("my_cloud_model")
        .setLocalModelName("my_local_model")
        .build();
FirebaseModelInterpreter firebaseInterpreter =
        FirebaseModelInterpreter.getInstance(options);

You can specify both a local and a cloud model at the same time. In that case the cloud model is used if it is available, and the local one otherwise.

Almost done: all that remains is to describe the input/output data, create the arrays, and run the model!

FirebaseModelInputOutputOptions inputOutputOptions =
    new FirebaseModelInputOutputOptions.Builder()
        // Types and shapes must match your model's input and output tensors
        .setInputFormat(0, FirebaseModelDataType.BYTE, new int[]{1, 640, 480, 3})
        .setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 784})
        .build();
// getYourInputData() stands for your own code that fills a byte[1][640][480][3] array
byte[][][][] input = getYourInputData();
FirebaseModelInputs inputs = new FirebaseModelInputs.Builder()
    .add(input)  // add() as many input arrays as your model requires
    .build();
firebaseInterpreter.run(inputs, inputOutputOptions)
    .addOnSuccessListener(
      new OnSuccessListener<FirebaseModelOutputs>() {
        @Override
        public void onSuccess(FirebaseModelOutputs result) {
          // The outputs are only available once the task has succeeded
          float[][] output = result.<float[][]>getOutput(0);
          float[] probabilities = output[0];
          // ...
        }
      })
    .addOnFailureListener(
      new OnFailureListener() {
        @Override
        public void onFailure(@NonNull Exception e) {
          // Task failed with an exception
          // ...
        }
      });
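The probabilities array holds one score per unit of the output layer. For a classifier with a softmax output, such as the Dense(10, activation="softmax", name="result") layer above, the predicted class is the index of the largest score. A minimal sketch of that step, in a hypothetical helper class named SoftmaxResult (not part of ML Kit), could look like this:

```java
import java.util.Arrays;

/** Hypothetical helpers for reading the probabilities returned by a softmax output layer. */
class SoftmaxResult {

    /** Index of the highest score, i.e. the predicted class. */
    static int argMax(float[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("output array is empty");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Indices of the k highest scores, best first; k is clamped to the array length. */
    static int[] topK(float[] probabilities, int k) {
        Integer[] order = new Integer[probabilities.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Sort indices by descending score; Arrays.sort on objects is stable,
        // so equal scores keep the lower index first.
        Arrays.sort(order, (a, b) -> Float.compare(probabilities[b], probabilities[a]));
        int n = Math.min(Math.max(k, 0), order.length);
        int[] result = new int[n];
        for (int i = 0; i < n; i++) {
            result[i] = order[i];
        }
        return result;
    }
}
```

Inside onSuccess you could then write int predicted = SoftmaxResult.argMax(probabilities); and use topK when you want to show several candidate labels.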

If for some reason you do not want to use Firebase, there is another way: create the TensorFlow Lite interpreter yourself and feed data to it directly.

Add this line to build.gradle:

implementation 'org.tensorflow:tensorflow-lite:+'

Create the interpreter and the arrays:

Interpreter tflite = new Interpreter(loadModelFile(getContext(), "model.tflite"));
// create the input/output arrays and fill inputs
tflite.run(inputs, outputs);
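Note that loadModelFile is not part of TensorFlow Lite; you write it yourself. For a model in assets, the usual approach is to open it with AssetManager.openFd, which gives an AssetFileDescriptor with the model's start offset and length inside the APK, and memory-map that range with FileChannel.map. openFd only works for uncompressed assets, which is why the noCompress "tflite" setting matters. The sketch below shows the same memory-mapping idea for a model stored as an ordinary file (for example one you downloaded into the app's files directory). The class name ModelFiles and its File-based signature are our own assumptions, not the signature used in the snippet above.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

/** Hypothetical helper: memory-maps a .tflite file so the interpreter can read it without copying. */
class ModelFiles {

    static MappedByteBuffer loadModelFile(File modelFile) throws IOException {
        try (FileInputStream in = new FileInputStream(modelFile);
             FileChannel channel = in.getChannel()) {
            // Map the whole file read-only. Per the FileChannel.map contract,
            // the mapping stays valid after the channel is closed.
            return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
        }
    }
}
```

The returned buffer can be passed to new Interpreter(...), which accepts a ByteBuffer; the Interpreter also has a constructor that takes a File directly if you prefer not to map it yourself.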

As you can see, this route needs much less code than the Firebase one.
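The interpreter snippet leaves creating and filling the inputs and outputs arrays to you. Interpreter.run accepts plain Java multidimensional arrays whose shapes match the model, so for the model converted above (--input_shapes=1,784 and a Dense(10) softmax output) that means float[1][784] in and float[1][10] out. The sketch below assumes the 784 inputs are a 28×28 grayscale image scaled to [0, 1]; that preprocessing is our assumption, and you should use whatever your network was trained with. The class name ModelArrays is hypothetical.

```java
/** Hypothetical helper: builds input/output arrays for a [1, 784] -> [1, 10] model. */
class ModelArrays {

    // Assumptions: the 784 inputs are a 28x28 grayscale image scaled to [0, 1],
    // and the output is the 10-unit softmax layer named "result".
    static final int SIDE = 28;
    static final int NUM_CLASSES = 10;

    /** Converts row-major ARGB pixels (e.g. from Bitmap.getPixels) into a [1][784] float input. */
    static float[][] toInput(int[] argbPixels) {
        if (argbPixels.length != SIDE * SIDE) {
            throw new IllegalArgumentException(
                    "expected " + SIDE * SIDE + " pixels, got " + argbPixels.length);
        }
        float[][] input = new float[1][SIDE * SIDE];
        for (int i = 0; i < argbPixels.length; i++) {
            int pixel = argbPixels[i];
            int r = (pixel >> 16) & 0xFF;
            int g = (pixel >> 8) & 0xFF;
            int b = pixel & 0xFF;
            // Average the channels to get a gray level, then scale 0..255 to 0..1.
            input[0][i] = (r + g + b) / 3f / 255f;
        }
        return input;
    }

    /** Output array the interpreter writes the class probabilities into. */
    static float[][] newOutput() {
        return new float[1][NUM_CLASSES];
    }
}
```

With these, the array step becomes float[][] inputs = ModelArrays.toInput(pixels); float[][] outputs = ModelArrays.newOutput(); followed by tflite.run(inputs, outputs), where pixels would come from, for instance, getPixels on a 28×28 Bitmap.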

That's all you need to run your neural network on Android.

Useful links:

Official ML Kit documentation
TensorFlow Lite
