diff --git a/ihmc-high-level-behaviors/build.gradle.kts b/ihmc-high-level-behaviors/build.gradle.kts index 4b0c65e3e1e..54c9545d6f7 100644 --- a/ihmc-high-level-behaviors/build.gradle.kts +++ b/ihmc-high-level-behaviors/build.gradle.kts @@ -19,6 +19,7 @@ mainDependencies { } api("us.ihmc:promp-java:1.0.2") api("us.ihmc:llama.cpp-javacpp:b4829-1") + api("org.msgpack:msgpack-core:0.9.10") // openpi client } libgdxDependencies { diff --git a/ihmc-high-level-behaviors/src/main/java/us/ihmc/openpi/OpenpiClient.java b/ihmc-high-level-behaviors/src/main/java/us/ihmc/openpi/OpenpiClient.java new file mode 100644 index 00000000000..4e9dbaed4a5 --- /dev/null +++ b/ihmc-high-level-behaviors/src/main/java/us/ihmc/openpi/OpenpiClient.java @@ -0,0 +1,229 @@ +package us.ihmc.openpi; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import org.msgpack.core.MessageBufferPacker; +import org.msgpack.core.MessagePack; +import org.msgpack.core.MessageUnpacker; +import us.ihmc.commons.exception.DefaultExceptionHandler; +import us.ihmc.log.LogTools; +import us.ihmc.robotics.robotSide.RobotSide; +import us.ihmc.robotics.robotSide.SideDependentList; + +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +public class OpenpiClient +{ + private final String host; + private final int port = 8000; + private final int stateSize; + private EventLoopGroup group; + private Channel channel; + private OpenpiNettyWebSocketHandler handler; + private final MessageBufferPacker packer = MessagePack.newDefaultBufferPacker(); + private final ByteBuffer state; + private final SideDependentList images = new SideDependentList<>(ByteBuffer.allocate(3 * 224 * 224), // uint8 rgb Channel - Height - Width + ByteBuffer.allocate(3 * 224 * 224)); + private final ByteBuffer actions; + private float policyTimingMs; + private float serverTimingMs; + + public OpenpiClient(String host, int stateSize) + { + this.host = host; + this.stateSize = stateSize; + + state = ByteBuffer.allocate(stateSize * Float.BYTES); + state.order(ByteOrder.nativeOrder()); + for (RobotSide side : RobotSide.values) + images.get(side).order(ByteOrder.nativeOrder()); + actions = ByteBuffer.allocate(50 * stateSize * Double.BYTES); + actions.order(ByteOrder.nativeOrder()); + } + + public CompletableFuture request() + { + try + { + if (channel == null || !channel.isActive()) // reconnect if necessary + { + destroy(); + + try + { + group = new NioEventLoopGroup(); + Bootstrap bootstrap = new Bootstrap(); + handler = new OpenpiNettyWebSocketHandler(new URI("ws://" + host + ":" + port)); + bootstrap.group(group).channel(NioSocketChannel.class).handler(new ChannelInitializer() + { + @Override + protected void initChannel(SocketChannel ch) + { + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(new HttpClientCodec()); + pipeline.addLast(new HttpObjectAggregator(65536)); + pipeline.addLast(handler); + } + }); + channel = bootstrap.connect(host, port).sync().channel(); + handler.handshakeFuture().sync(); + handler.awaitFirstMessage(10, TimeUnit.SECONDS); + } + catch (Exception exception) + { + return null; + } + } + + packer.clear(); + packer.packMapHeader(4); + for (RobotSide side : RobotSide.values) + { + packer.packString("cam_zed_%s".formatted(side.getLowerCaseName())).packMapHeader(4); + packer.packString("__ndarray__").packBoolean(true); + byte[] imgData = images.get(side).array(); + packer.packString("data").packBinaryHeader(imgData.length).writePayload(imgData); + packer.packString("dtype").packString("uint8"); + packer.packString("shape").packArrayHeader(3); + packer.packInt(3); // channels, rgb + packer.packInt(224); // height + packer.packInt(224); // width + } + packer.packString("state").packMapHeader(4); + packer.packString("__ndarray__").packBoolean(true); + packer.packString("data").packBinaryHeader(state.array().length).writePayload(state.array()); + packer.packString("dtype").packString("float32"); + packer.packString("shape").packArrayHeader(1).packInt(stateSize); + packer.packString("prompt").packString("touch door handle"); // TODO + + return handler.sendAndAwaitResponse(packer.toByteArray()); + } + catch (Exception e) + { + throw new RuntimeException("Request failed", e); + } + } + + public void unpack(CompletableFuture response) + { + try + { + MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(response.get()); + unpacker.unpackMapHeader(); // 3 + unpacker.unpackString(); // actions + unpacker.unpackMapHeader(); // 4 + unpacker.unpackString(); // __ndarray__ + unpacker.unpackBoolean(); // true + unpacker.unpackString(); // data + unpacker.unpackBinaryHeader(); // 50 * STATE_SIZE * 8 + unpacker.readPayload(actions.array()); + unpacker.unpackString(); // dtype + unpacker.unpackString(); // getImages() + { + return images; + } + + public ByteBuffer getActionChunk() + { + return actions; + } + + public float getPolicyTimingMs() + { + return policyTimingMs; + } + + public float getServerTimingMs() + { + return serverTimingMs; + } + + public void destroy() + { + if (channel != null) + channel.close(); + if (group != null) + group.shutdownGracefully(); + + channel = null; + group = null; + } + + public boolean hasBeenStarted() + { + return group != null; + } + + public String getHost() + { + return host; + } + + public int getPort() + { + return port; + } + + public static void main(String[] args) + { + int stateSize = 3; + OpenpiClient client = new OpenpiClient("10.6.192.65", 3); + + for (RobotSide side : RobotSide.values) + { + byte[] imgData = client.getImages().get(side).array(); + for (int i = 0; i < imgData.length; i++) + imgData[i] = (byte) (i % 256); + } + + ByteBuffer stateData = client.getState(); + for (int i = 0; i < stateSize; i++) + stateData.putFloat((float) (i % stateSize)); + + CompletableFuture request = client.request(); + client.unpack(request); + + LogTools.info("Action chunk: %d".formatted(client.getActionChunk().limit())); + + client.destroy(); + } +} + diff --git a/ihmc-high-level-behaviors/src/main/java/us/ihmc/openpi/OpenpiNettyWebSocketHandler.java b/ihmc-high-level-behaviors/src/main/java/us/ihmc/openpi/OpenpiNettyWebSocketHandler.java new file mode 100644 index 00000000000..aaa2640b749 --- /dev/null +++ b/ihmc-high-level-behaviors/src/main/java/us/ihmc/openpi/OpenpiNettyWebSocketHandler.java @@ -0,0 +1,170 @@ +package us.ihmc.openpi; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory; +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; +import io.netty.handler.codec.http.websocketx.WebSocketVersion; +import us.ihmc.log.LogTools; + +import java.net.URI; +import java.util.LinkedList; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +class OpenpiNettyWebSocketHandler extends SimpleChannelInboundHandler +{ + private final URI uri; + private final WebSocketClientHandshaker handshaker; + private ChannelPromise handshakeFuture; + + private final Queue> pendingResponses = new LinkedList<>(); + private final CountDownLatch firstMessageLatch = new CountDownLatch(1); + private byte[] firstMessage; + + public OpenpiNettyWebSocketHandler(URI uri) + { + this.handshaker = WebSocketClientHandshakerFactory.newHandshaker(uri, WebSocketVersion.V13, null, false, new DefaultHttpHeaders()); + this.uri = uri; + } + + public ChannelFuture handshakeFuture() + { + return handshakeFuture; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) + { + handshakeFuture = ctx.newPromise(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) + { + handshaker.handshake(ctx.channel()); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) + { + Channel ch = ctx.channel(); + + if (!handshaker.isHandshakeComplete()) + { + try + { + handshaker.finishHandshake(ch, (FullHttpResponse) msg); + LogTools.info("Connected to openpi server at " + uri); + handshakeFuture.setSuccess(); + } + catch (WebSocketHandshakeException e) + { + LogTools.error("WebSocket handshake failed"); + handshakeFuture.setFailure(e); + } + return; + } + + if (msg instanceof FullHttpResponse response) + { + throw new IllegalStateException("Unexpected FullHttpResponse (getStatus=" + response.status() + ", content=" + response.content().toString() + ')'); + } + + if (msg instanceof TextWebSocketFrame textFrame) + { + // Server error path - convert to exception like Python client + String errorMsg = textFrame.text(); + LogTools.error("Error from server: " + errorMsg); + + // Complete any pending futures with exception + CompletableFuture pending = pendingResponses.poll(); + if (pending != null) + { + pending.completeExceptionally(new RuntimeException("Server error: " + errorMsg)); + } + } + else if (msg instanceof BinaryWebSocketFrame binaryFrame) + { + ByteBuf content = binaryFrame.content(); + byte[] data = new byte[content.readableBytes()]; + content.readBytes(data); + + // Handle first message (server metadata) + if (firstMessage == null) + { + firstMessage = data; + firstMessageLatch.countDown(); + } + else + { + // Handle response to inference request + CompletableFuture pending = pendingResponses.poll(); + if (pending != null) + { + pending.complete(data); + } + } + } + else if (msg instanceof CloseWebSocketFrame) + { + LogTools.warn("WebSocket connection closed"); + ch.close(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + { + cause.printStackTrace(); + if (!handshakeFuture.isDone()) + { + handshakeFuture.setFailure(cause); + } + ctx.close(); + } + + public CompletableFuture sendAndAwaitResponse(byte[] data) + { + CompletableFuture future = new CompletableFuture<>(); + pendingResponses.offer(future); + + ByteBuf buffer = Unpooled.wrappedBuffer(data); + BinaryWebSocketFrame frame = new BinaryWebSocketFrame(buffer); + + handshakeFuture.channel().writeAndFlush(frame).addListener(channelFuture -> + { + if (!channelFuture.isSuccess()) + { + future.completeExceptionally(channelFuture.cause()); + } + }); + + return future; + } + + public byte[] awaitFirstMessage(long timeout, TimeUnit unit) throws Exception + { + if (firstMessageLatch.await(timeout, unit)) + { + return firstMessage; + } + else + { + throw new RuntimeException("Timeout waiting for server metadata"); + } + } +}