最近想写一个算法,考虑到执行效率和开发效率,准备使用Java来实现,不过在算法执行过程中我打算使用基于pytorch的神经网络,此时就涉及如何在Java语言中执行Python程序。
最简单的想法当然就是直接使用命令行调用:Runtime.getRuntime().exec("python xxx.py --yyy zzz"),但是这种方式每次调用python脚本都需要重启python解释器以及重新将pytorch神经网络加载进内存,如果调用频繁的话就会非常耗时,因此并不合适。
更好的方式是让两个程序通过网络来通信,python可以启动时加载好神经网络,之后就开一个socket服务等待请求,让java程序把参数通过socket发给python程序,然后python拿到参数后执行函数,再把返回值发送回java即可。python程序并不退出,因此就不需要重复神经网络的加载过程了。而且基于socket可以让java和python跑在不同的机器上。
下面给出一个简单的java做客户端,python做服务端的代码范例,首先是python服务端:
import socket
def py_func(x: str) -> str:
actual_str = x.encode().decode('unicode_escape') # parse escape character
print(f'process 【{actual_str}】')
ret_str = actual_str.encode('unicode_escape').decode() # recover escape character
return f'Python process 【{ret_str}】 done.'
if __name__ == '__main__':
sock = socket.socket()
sock.bind(('127.0.0.1', 8080))
sock.listen()
sock.settimeout(0.5)
print('server open')
print('waiting connect', end='')
quit = False
while not quit:
try:
print('.', end='', flush=True)
conn, _ = sock.accept()
conn_f = conn.makefile('rw', encoding='utf8')
print()
while True:
print('waiting receive...')
recv_s = conn_f.readline()[:-1]
if recv_s == '[close server]':
quit = True
break
else:
print('get:', recv_s)
try:
send_s = py_func(recv_s)
except Exception as e:
send_s = f'{e.__class__.__name__}: {e}'
print('send:', send_s)
conn_f.write(send_s + '\n')
conn_f.flush()
conn.close()
except TimeoutError:
pass
except (ConnectionResetError, ConnectionAbortedError) as e:
print(f'{e.__class__.__name__}: {e}')
conn.close()
print('waiting connect', end='')
sock.close()
print('server close')
其中py_func就是我们需要python去执行的功能,在这里我给输入的字符串进行了转义字符解析,也就是如果输入的是r'a\nb',即四个字符:a、反斜线、n、b,那么actual_str会存储为'a\nb',即三个字符:a、换行符、b。之后actual_str使用完毕,会再把里面的换行符转换为r'\n'两个字符存入ret_str。
注意我给socket设置了timeout,这是为了能够响应ctrl+c的键盘中断。发送send_s时我加了一个换行符,这是为了java在之后能够使用nextLine获取响应信息,如果结尾没有换行符,在java中进行nextLine就会一直阻塞在读取响应中。
下面是java客户端:
import java.net.Socket;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Scanner;
public class app {
private static Scanner input_reader = new Scanner(System.in);
public static void main(String[] args) throws IOException {
var quit = false;
while (!quit) {
var clientSocket = new Socket("127.0.0.1", 8080);
var reader = new Scanner(clientSocket.getInputStream(), StandardCharsets.UTF_8);
var writer = new PrintWriter(clientSocket.getOutputStream(), true, StandardCharsets.UTF_8);
while (true) {
var input = input_reader.nextLine();
if (input.equals("[close]")) {
quit = true;
break;
}
if (input.equals("[reconnect]")) {
break;
}
writer.println(input);
System.out.println("get: " + reader.nextLine());
}
reader.close();
clientSocket.close();
}
}
}
之后我们首先启动python服务端,然后运行java程序:java app,输入任何句子然后回车,就能获取到python的返回值了。
参考资料: