Как да се позове на обучен модел TensorFlow от програми на Java

Основният език, на който се създават и обучават модели за машинно обучение на TensorFlow, е Python. Въпреки това, много програми за сървър на корпоративни клиенти са написани на Java. Така че често ще се сблъскате със ситуации, в които трябва да се позовете на модела Tensorflow, който сте тренирали в Python от програма на Java.

Ако използвате Cloud ML Engine в платформата на Google Cloud, това не е проблем - в Cloud ML Engine прогнозите се правят чрез REST API разговор и така можете да направите това от всеки език на програмиране. Но какво ще стане, ако сте изтеглили модела TensorFlow и искате да извършвате прогнози офлайн?

Ето как можете да правите прогнози в Java с помощта на модели Tensorflow, които бяха обучени в Python.

Забележка: Екипът на Tensorflow вече започна да добавя връзки към Java. Вижте https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java за подробности. Опитайте това първо, и ако не работи за вас, ела тук ...

Изпишете модели на файлове в Python

Първото нещо, което трябва да направите, е да запазите модела TensorFlow в Python в два формата: (а) теглата, отклоненията и т.н. като файл „saver_def“ (б) самата графика като файл с протобуф. За да запазите здравия си разум, може да искате да запазите графиката като текст и като двоичен формат протобуф. Ще ви бъде полезно да прочетете текстовия формат, за да намерите имената, присвоени от TensorFlow на възли, на които не сте посочили изрично имена. Кодът за записване на тези три файла от Python:

# създайте обект Saver като нормален в Python, за да запазите променливите си
saver = tf.train.Saver (...)
# Използвайте saver_def, за да получите "вълшебните" низове за възстановяване
saver_def = saver.as_saver_def ()
печат saver_def.filename_tensor_name
печат saver_def.restore_op_name
# изпишете 3 файла
saver.save (sess, 'Training_model.sd')
tf.train.write_graph (sess.graph_def, '.', 'Training_model.proto', as_text = Грешно)
tf.train.write_graph (sess.graph_def, '.', 'Training_model.txt', as_text = Вярно)

В моя случай двата вълшебни низа, отпечатани от save_def, бяха save / Const: 0 и save / Resto_all - така че това ще видите в моя Java код. Променете ги, когато пишете своя Java код, ако вашият е различен.

.Sd файлът съдържа тегла, отклонения и т.н. (действителните стойности за променливите във вашата графика). Файлът .proto е двоичен файл, съдържащ вашата изчислителна графика и .txt съответната текстова версия.

Извикване на Tensorflow C ++ от Java

Въпреки че може би сте използвали Tensorflow в Python, за да захранвате данни към вашия модел и да ги обучавате, пакетът Tensorflow Python всъщност изисква C ++ реализация, за да извърши действителната работа. Следователно, можем да използваме Java Native Interface (JNI) за директно извикване на C ++ и да използваме C ++ за създаване на графиката и възстановяване на теглата и отклоненията от модела от Java.

Вместо да пишете на ръка всички обаждания на JNI, е възможно да използвате библиотека с отворен код, наречена JavaCpp, за да направите това. За да използвате JavaCpp, добавете тази зависимост към вашия Java Maven pom.xml:

<Зависимост>
   org.bytedeco.javacpp-настройки 
   tensorflow 
  <Версия> 0.9.0-1.2 

Ако използвате друга система за управление на сглобяването, добавете предварително зададени Javacpp за tensorflow и всички негови зависимости към класния път на приложението ви.

Създайте модел в Java

Във вашия Java код прочетете прото файла, за да създадете определение на Graph както следва (импортирането е пропуснато за яснота):

заключителна сесия = нова сесия (нови SessionOptions ());
GraphDef def = нов GraphDef ();
tensorflow.ReadBinaryProto (Env.Default (),
                           "somedir / Training_model.proto", def);
Състояние s = сесия.Създайте (def);
ако (! s.ok ()) {
    хвърлете нова RuntimeException (s.error_message (). getString ());
}

След това възстановете теглата и отклоненията от записания модел на файла с помощта на Session :: Run (). Обърнете внимание как се използват вълшебните струни от saver_def.

// Възстанови
Tensor fn = нов Tensor (tensorflow.DT_STRING, нов TensorShape (1));
StringArray a = fn.createStringArray ();
a.position (0) .put ( "somedir / trained_model.sd");
s = session.Run (нов StringTensorPairVector (нов String [] {“запазване / Const: 0”}, нов Tensor [] {fn}), нов StringVector (), нов StringVector (“save / Resto_all”), нов TensorVector ( ));
ако (! s.ok ()) {
   хвърлете нова RuntimeException (s.error_message (). getString ());
}

Правене на прогнози в Java

В този момент вашият модел е готов. Сега можете да го използвате, за да правите прогнози. Това е подобно на това как го правите в Python - трябва да предадете стойности за всичките си заместители и да оцените изходния възел. Разликата е, че трябва да знаете действителните имена на заместителя и изходните възли. Ако не сте задали тези възли уникални имена в Python, Tensorflow им е присвоил имена. Можете да разберете какви са те, като погледнете файла, който е изписан. Или можете да се върнете към своя Python код и да присвоите имената на ключовите възли, които помните. В моя случай, входният заместител на име се нарича Placeholder; запазващият възел на отпадащия възел се нарича Placeholder_2, а изходният възел се нарича Sigmoid. Ще видите тези препратки в обаждането на сесия :: Run () по-долу.

В моя случай невронната мрежа използва 5 променливи променливи. Ако приемем, че имам масив от входове, които са предсказателите на моя модел на невронната мрежа и искам да направя прогнозата за 2 комплекта такива входове, моят вход е матрица 2x5. Моят NN има само един изход, така че за 2 комплекта входове, изходният тензор е матрица 2x1. На отпадащия възел се дава кодиран вход от 1.0 (при прогнозиране запазваме всички възли - вероятността за отпадане е само за обучение). И така, имам:

// опитайте се да прогнозирате за два (2) набора от входове.
Тензорни входове = нов тензор (
         tensorflow.DT_FLOAT, нова TensorShape (2,5));
FloatBuffer x = inputs.createBuffer ();
x.put (нов поплавък [] {- ​​6.0f, 22.0f, 383.0f, 27.781754111198122f, -6.5f});
x.put (нов поплавък [] {66.0f, 22.0f, 2422.0f, 45.72160947712418f, 0.4f});
Tensor Keepall = нов Tensor (
        tensorflow.DT_FLOAT, нова TensorShape (2,1));
((FloatBuffer) Keepall.createBuffer ()). Put (new float [] {1f, 1f});
Изходи TensorVector = нов TensorVector ();
// да предвижда всеки път, да предава стойности за заместители
outputs.resize (0);
s = session.Run (нов StringTensorPairVector (нов String [] {„Заместител на място“, „Placeholder_2“}, нов Tensor [] {входове, поддържане}),
 нов StringVector (“Sigmoid”), нов StringVector (), изходи);
ако (! s.ok ()) {
   хвърлете нова RuntimeException (s.error_message (). getString ());
}
// по този начин връщате прогнозната стойност от изходите
FloatBuffer output = outputs.get (0) .createBuffer ();
за (int k = 0; k 

Това е всичко - сега използвате Java, за да изпълните прогнозите си. Има няколко стъпки, но това трябва да се очаква, когато човек смесва 3 програмни езика (Python, C ++ и Java). Но важното е, че може да се направи и че е сравнително прям.

Разбира се, ако това не се възползва от хардуерното ускорение и разпространение. Ако искате да правите прогнози с много висока скорост в реално време, трябва да помислите да използвате Cloud ML Engine.