在Java中调用Python函数:基于socket的方法

最近想写一个算法,考虑到执行效率和开发效率,准备使用Java来实现,不过在算法执行过程中我打算使用基于pytorch的神经网络,此时就涉及如何在Java语言中执行Python程序。

最简单的想法当然就是直接使用命令行调用:Runtime.getRuntime().exec("python xxx.py --yyy zzz"),但是这种方式每次调用python脚本都需要重启python解释器以及重新将pytorch神经网络加载进内存,如果调用频繁的话就会非常耗时,因此并不合适。

更好的方式是让两个程序通过网络来通信,python可以启动时加载好神经网络,之后就开一个socket服务等待请求,让java程序把参数通过socket发给python程序,然后python拿到参数后执行函数,再把返回值发送回java即可。python程序并不退出,因此就不需要重复神经网络的加载过程了。而且基于socket可以让java和python跑在不同的机器上。

下面给出一个简单的java做客户端,python做服务端的代码范例,首先是python服务端:

import socket


def py_func(x: str) -> str:
    actual_str = x.encode().decode('unicode_escape')  # parse escape character
    print(f'process 【{actual_str}】')
    ret_str = actual_str.encode('unicode_escape').decode() # recover escape character
    return f'Python process 【{ret_str}】 done.'


if __name__ == '__main__':
    sock = socket.socket()
    sock.bind(('127.0.0.1', 8080))
    sock.listen()
    sock.settimeout(0.5)

    print('server open')
    print('waiting connect', end='')

    quit = False
    while not quit:
        try:
            print('.', end='', flush=True)

            conn, _ = sock.accept()
            conn_f = conn.makefile('rw', encoding='utf8')

            print()

            while True:
                print('waiting receive...')
                recv_s = conn_f.readline()[:-1]

                if recv_s == '[close server]':
                    quit = True
                    break
                else:
                    print('get:', recv_s)

                try:
                    send_s = py_func(recv_s)
                except Exception as e:
                    send_s = f'{e.__class__.__name__}: {e}'

                print('send:', send_s)
                conn_f.write(send_s + '\n')
                conn_f.flush()

            conn.close()

        except TimeoutError:
            pass
        except (ConnectionResetError, ConnectionAbortedError) as e:
            print(f'{e.__class__.__name__}: {e}')
            conn.close()
            print('waiting connect', end='')

    sock.close()
    print('server close')

其中py_func就是我们需要python去执行的功能,在这里我给输入的字符串进行了转义字符解析,也就是如果输入的是r'a\nb',即四个字符:a、反斜线、n、b,那么actual_str会存储为'a\nb',即三个字符:a、换行符、b。之后actual_str使用完毕,会再把里面的换行符转换为r'\n'两个字符存入ret_str。

注意我给socket设置了timeout,这是为了能够响应ctrl+c的键盘中断。发送send_s时我加了一个换行符,这是为了java在之后能够使用nextLine获取响应信息,如果结尾没有换行符,在java中进行nextLine就会一直阻塞在读取响应中。

下面是java客户端:

import java.net.Socket;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Scanner;

public class app {
    private static Scanner input_reader = new Scanner(System.in);

    public static void main(String[] args) throws IOException {

        var quit = false;
        while (!quit) {

            var clientSocket = new Socket("127.0.0.1", 8080);
            var reader = new Scanner(clientSocket.getInputStream(), StandardCharsets.UTF_8);
            var writer = new PrintWriter(clientSocket.getOutputStream(), true, StandardCharsets.UTF_8);

            while (true) {
                var input = input_reader.nextLine();

                if (input.equals("[close]")) {
                    quit = true;
                    break;
                }

                if (input.equals("[reconnect]")) {
                    break;
                }

                writer.println(input);

                System.out.println("get: " + reader.nextLine());
            }

            reader.close();
            clientSocket.close();
        }
    }
}

之后我们首先启动python服务端,然后运行java程序:java app,输入任何句子然后回车,就能获取到python的返回值了。

参考资料:

留下评论