diff --git a/.bazelrc b/.bazelrc index abf05fc6d7..c10fb589bc 100644 --- a/.bazelrc +++ b/.bazelrc @@ -50,6 +50,8 @@ build --features=per_object_debug_info # We already have absl in the build, define absl=1 to tell googletest to use absl for backtrace. build --define absl=1 +build:rdma --define BRPC_WITH_RDMA=true + # For UT. build:test --define BRPC_BUILD_FOR_UNITTEST=true # Hide libunwind's `_Unwind_*` symbols so they don't preempt libgcc_s at diff --git a/.github/actions/install-all-dependencies/action.yml b/.github/actions/install-all-dependencies/action.yml index 86d2884b97..5c1f673ff7 100644 --- a/.github/actions/install-all-dependencies/action.yml +++ b/.github/actions/install-all-dependencies/action.yml @@ -2,7 +2,7 @@ runs: using: "composite" steps: - uses: ./.github/actions/install-essential-dependencies - - run: sudo apt-get update && sudo apt-get install -y libunwind-dev libgoogle-glog-dev automake bison flex libboost-all-dev libevent-dev libtool pkg-config libibverbs1 libibverbs-dev + - run: sudo apt-get update && sudo apt-get install -y libunwind-dev libgoogle-glog-dev automake bison flex libboost-all-dev libevent-dev libtool pkg-config libibverbs-dev shell: bash - run: | wget https://archive.apache.org/dist/thrift/0.11.0/thrift-0.11.0.tar.gz && tar -xf thrift-0.11.0.tar.gz && cd thrift-0.11.0/ diff --git a/.github/workflows/ci-linux.yml b/.github/workflows/ci-linux.yml index 8a36af6024..a334b29126 100644 --- a/.github/workflows/ci-linux.yml +++ b/.github/workflows/ci-linux.yml @@ -29,7 +29,9 @@ jobs: - name: gcc with all options uses: ./.github/actions/compile-with-make with: - options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=gcc --cxx=g++ --werror --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety --with-debug-lock --with-bthread-tracer --with-asan + options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=gcc --cxx=g++ --werror \ + --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety \ + 
--with-debug-lock --with-bthread-tracer --with-asan - name: clang with default options uses: ./.github/actions/compile-with-make @@ -39,7 +41,9 @@ jobs: - name: clang with all options uses: ./.github/actions/compile-with-make with: - options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=clang --cxx=clang++ --werror --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety --with-debug-lock --with-bthread-tracer --with-asan + options: --headers=/usr/include --libs=/usr/lib /usr/lib64 --cc=clang --cxx=clang++ --werror \ + --with-thrift --with-glog --with-rdma --with-debug-bthread-sche-safety \ + --with-debug-lock --with-bthread-tracer --with-asan compile-with-cmake: runs-on: ubuntu-22.04 @@ -57,7 +61,9 @@ jobs: run: | export CC=gcc && export CXX=g++ mkdir gcc_build_all && cd gcc_build_all - cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. + cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON \ + -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON \ + -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. make -j ${{env.proc_num}} && make clean - name: clang with default options @@ -70,7 +76,9 @@ jobs: run: | export CC=clang && export CXX=clang++ mkdir clang_build_all && cd clang_build_all - cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. + cmake -DWITH_MESALINK=OFF -DWITH_GLOG=ON -DWITH_THRIFT=ON -DWITH_RDMA=ON \ + -DWITH_DEBUG_BTHREAD_SCHE_SAFETY=ON -DWITH_DEBUG_LOCK=ON -DWITH_BTHREAD_TRACER=ON \ + -DWITH_ASAN=ON -DCMAKE_POLICY_VERSION_MINIMUM=3.5 .. 
make -j ${{env.proc_num}} && make clean gcc-compile-with-make-protobuf: @@ -160,6 +168,7 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 + - run: sudo apt-get update && sudo apt-get install -y libibverbs-dev - run: | bazel test --test_output=streamed \ --action_env=CC=clang \ @@ -229,6 +238,7 @@ jobs: USE_BAZEL_VERSION: "8.3.1" steps: - uses: actions/checkout@v2 + - run: sudo apt-get update && sudo apt-get install -y libibverbs-dev - name: Override protobuf version for testing run: | sed -i -E "s/(bazel_dep\(name = ['\"]protobuf['\"], version = ['\"])[^'\"]+/\1${TEST_PROTOBUF_VERSION}/" MODULE.bazel @@ -237,7 +247,6 @@ jobs: grep -qE "bazel_dep\(name = ['\"]protobuf['\"], version = ['\"]${TEST_PROTOBUF_VERSION}['\"]" MODULE.bazel \ || { echo "ERROR: failed to override protobuf version in MODULE.bazel to ${TEST_PROTOBUF_VERSION}"; exit 1; } - run: | - bazel test --action_env=CC=clang \ + bazel test --action_env=CC=clang --config=rdma \ --define with_babylon_counter=true \ - --define with_babylon_counter=true \ - //test:brpc_unittests + //test/... 
--test_arg=--gtest_filter=-RdmaRpcTest.* diff --git a/BUILD.bazel b/BUILD.bazel index 22cb508548..b51ee0f6b0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -506,6 +506,7 @@ filegroup( srcs = glob([ "src/brpc/*.proto", "src/brpc/policy/*.proto", + "src/brpc/rdma/*.proto", ]), visibility = ["//visibility:public"], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5e74007b66..a3ebb855cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -555,7 +555,8 @@ set(PROTO_FILES idl_options.proto brpc/policy/mongo.proto brpc/trackme.proto brpc/streaming_rpc_meta.proto - brpc/proto_base.proto) + brpc/proto_base.proto + brpc/rdma/rdma_handshake.proto) file(MAKE_DIRECTORY ${PROJECT_BINARY_DIR}/output/include/brpc) set(PROTOC_FLAGS ${PROTOC_FLAGS} -I${PROTOBUF_INCLUDE_DIR}) compile_proto(PROTO_HDRS PROTO_SRCS ${PROJECT_BINARY_DIR} diff --git a/src/brpc/rdma/rdma_endpoint.cpp b/src/brpc/rdma/rdma_endpoint.cpp index f09d723ca1..658c7a2fcc 100644 --- a/src/brpc/rdma/rdma_endpoint.cpp +++ b/src/brpc/rdma/rdma_endpoint.cpp @@ -31,6 +31,7 @@ #include "brpc/rdma/rdma_helper.h" #include "brpc/rdma/rdma_endpoint.h" #include "brpc/rdma_transport.h" +#include "brpc/rdma/rdma_handshake.h" DECLARE_int32(task_group_ntags); @@ -70,84 +71,30 @@ static const size_t IOBUF_BLOCK_HEADER_LEN = 32; // implementation-dependent // DO NOT change this value unless you know the safe value!!! // This is the number of reserved WRs in SQ/RQ for pure ACK. 
-static const size_t RESERVED_WR_NUM = 3; - -// magic string RDMA (4B) -// message length (2B) -// hello version (2B) -// impl version (2B): 0 means should use tcp -// block size (4B) -// sq size (2B) -// rq size (2B) -// GID (16B) -// QP number (4B) -static const char* MAGIC_STR = "RDMA"; -static const size_t MAGIC_STR_LEN = 4; -static const size_t HELLO_MSG_LEN_MIN = 40; -// static const size_t HELLO_MSG_LEN_MAX = 4096; -static const size_t ACK_MSG_LEN = 4; -static uint16_t g_rdma_hello_msg_len = 40; // In Byte -static uint16_t g_rdma_hello_version = 2; -static uint16_t g_rdma_impl_version = 1; -static uint32_t g_rdma_recv_block_size = 0; +extern const size_t RESERVED_WR_NUM = 3; + +// The local recv block size, set during GlobalInitialize. +uint32_t g_rdma_recv_block_size = 0; // static const uint32_t MAX_INLINE_DATA = 64; static const uint8_t MAX_HOP_LIMIT = 16; static const uint8_t TIMEOUT = 14; static const uint8_t RETRY_CNT = 7; -static const uint16_t MIN_QP_SIZE = 16; +extern const uint16_t MIN_QP_SIZE = 16; static const uint16_t MAX_QP_SIZE = 4096; -static const uint16_t MIN_BLOCK_SIZE = 1024; -static const uint32_t ACK_MSG_RDMA_OK = 0x1; +extern const uint16_t MIN_BLOCK_SIZE = 1024; + +// ACK message wire format (shared by all protocol versions): a single +// 4B big-endian flags word; bit 0 (HELLO_ACK_RDMA_OK) indicates the +// sender wants to use RDMA. The state machines in +// ProcessHandshakeAt{Client,Server} inline the corresponding 4B +// send/recv directly using ReadFromFd / WriteToFd. 
+static const size_t HELLO_ACK_LEN = 4; +static const uint32_t HELLO_ACK_RDMA_OK = 0x1; static butil::Mutex* g_rdma_resource_mutex = NULL; static RdmaResource* g_rdma_resource_list = NULL; -struct HelloMessage { - void Serialize(void* data) const; - void Deserialize(void* data); - - uint16_t msg_len; - uint16_t hello_ver; - uint16_t impl_ver; - uint32_t block_size; - uint16_t sq_size; - uint16_t rq_size; - uint16_t lid; - ibv_gid gid; - uint32_t qp_num; -}; - -void HelloMessage::Serialize(void* data) const { - uint16_t* current_pos = (uint16_t*)data; - *(current_pos++) = butil::HostToNet16(msg_len); - *(current_pos++) = butil::HostToNet16(hello_ver); - *(current_pos++) = butil::HostToNet16(impl_ver); - uint32_t* block_size_pos = (uint32_t*)current_pos; - *block_size_pos = butil::HostToNet32(block_size); - current_pos += 2; // move forward 4 Bytes - *(current_pos++) = butil::HostToNet16(sq_size); - *(current_pos++) = butil::HostToNet16(rq_size); - *(current_pos++) = butil::HostToNet16(lid); - memcpy(current_pos, gid.raw, 16); - uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16); - *qp_num_pos = butil::HostToNet32(qp_num); -} - -void HelloMessage::Deserialize(void* data) { - uint16_t* current_pos = (uint16_t*)data; - msg_len = butil::NetToHost16(*current_pos++); - hello_ver = butil::NetToHost16(*current_pos++); - impl_ver = butil::NetToHost16(*current_pos++); - block_size = butil::NetToHost32(*(uint32_t*)current_pos); - current_pos += 2; // move forward 4 Bytes - sq_size = butil::NetToHost16(*current_pos++); - rq_size = butil::NetToHost16(*current_pos++); - lid = butil::NetToHost16(*current_pos++); - memcpy(gid.raw, current_pos, 16); - qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16)); -} - RdmaResource::~RdmaResource() { if (NULL != qp) { IbvDestroyQp(qp); @@ -169,6 +116,7 @@ RdmaResource::~RdmaResource() { RdmaEndpoint::RdmaEndpoint(Socket* s) : _socket(s) , _state(UNINIT) + , _handshake_version(0) , _resource(NULL) , _send_cq_events(0) , 
_recv_cq_events(0) @@ -348,31 +296,34 @@ void RdmaEndpoint::OnNewDataFromTcp(Socket* m) { } } -bool HelloNegotiationValid(HelloMessage& msg) { - if (msg.hello_ver == g_rdma_hello_version && - msg.impl_ver == g_rdma_impl_version && - msg.block_size >= MIN_BLOCK_SIZE && - msg.sq_size >= MIN_QP_SIZE && - msg.rq_size >= MIN_QP_SIZE) { - // This can be modified for future compatibility - return true; - } - return false; -} - static const int WAIT_TIMEOUT_MS = 50; -int RdmaEndpoint::ReadFromFd(void* data, size_t len) { - CHECK(data != NULL); - int nr = 0; +// Drive an EAGAIN-aware read loop to completion (exactly `len` bytes). +// `read_once(offset, remaining)` performs ONE underlying read attempt: +// - returns > 0 : number of bytes consumed (added to running total); +// - returns = 0 : end-of-stream (the loop fails with EEOF); +// - returns < 0 : errno set; EAGAIN is handled here via butex_wait, +// any other errno bubbles up. +// `offset` is bytes already received in THIS call (initially 0); the +// callable uses it to choose the next write target (e.g. `(char*)buf +// + offset`). Callables that don't need offset (e.g. IOPortal append) +// can ignore it. +// +// Centralizes the EAGAIN/butex/EOF loop so the two ReadFromFd +// overloads below stay one-liners; any future read source (memory- +// mapped, scatter-vector, etc.) can plug in by passing its own +// `read_once`. 
+template +static int ReadFromFdLoop(butil::atomic* read_butex, + size_t len, ReadOnce&& read_once) { size_t received = 0; - do { - const int expected_val = _read_butex->load(butil::memory_order_acquire); + while (received < len) { + const int expected_val = read_butex->load(butil::memory_order_acquire); const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); - nr = read(_socket->fd(), (uint8_t*)data + received, len - received); + ssize_t nr = read_once(received, len - received); if (nr < 0) { if (errno == EAGAIN) { - if (bthread::butex_wait(_read_butex, expected_val, &duetime) < 0) { + if (bthread::butex_wait(read_butex, expected_val, &duetime) < 0) { if (errno != EWOULDBLOCK && errno != ETIMEDOUT) { return -1; } @@ -386,34 +337,89 @@ int RdmaEndpoint::ReadFromFd(void* data, size_t len) { } else { received += nr; } - } while (received < len); + } return 0; } -int RdmaEndpoint::WriteToFd(void* data, size_t len) { +int RdmaEndpoint::ReadFromFd(void* data, size_t len) { + CHECK(data != NULL); + const int fd = _socket->fd(); + return ReadFromFdLoop(_read_butex, len, + [data, fd](size_t offset, size_t remaining) { + return read(fd, (uint8_t*)data + offset, remaining); + }); +} + +int RdmaEndpoint::ReadFromFd(butil::IOPortal* data, size_t len) { CHECK(data != NULL); - int nw = 0; + const int fd = _socket->fd(); + return ReadFromFdLoop(_read_butex, len, + [data, fd](size_t /*offset*/, size_t remaining) { + return data->append_from_file_descriptor(fd, remaining); + }); +} + +// Drive an EAGAIN-aware write loop to completion (exactly `len` bytes). +// +// `write_once(offset, remaining)` performs ONE underlying write attempt: +// - returns >= 0 : number of bytes consumed (added to running total); +// - returns < 0 : errno set; EAGAIN triggers `wait_writable(duetime)`, +// any other errno bubbles up. +// `offset` is bytes already written in THIS call (initially 0); the +// callable uses it to choose the next read source (e.g. `(char*)buf +// + offset`). 
Callables that drain a self-tracking sink (e.g. +// IOBuf::cut_into_file_descriptor) can ignore both args. +// +// `wait_writable(duetime)` is invoked on EAGAIN to park until the fd +// becomes writable again. It returns true on wake-up (or ETIMEDOUT), +// false on hard failure (the loop then returns -1). +template +static int WriteToFdLoop(size_t len, WriteOnce&& write_once, WaitWritable&& wait_writable) { size_t written = 0; - do { + while (written < len) { const timespec duetime = butil::milliseconds_from_now(WAIT_TIMEOUT_MS); - nw = write(_socket->fd(), (uint8_t*)data + written, len - written); - if (nw < 0) { - if (errno == EAGAIN) { - if (_socket->WaitEpollOut(_socket->fd(), true, &duetime) < 0) { - if (errno != ETIMEDOUT) { - return -1; - } - } - } else { - return -1; - } - } else { + ssize_t nw = write_once(written, len - written); + if (nw >= 0) { written += nw; + continue; + } + + if (errno != EAGAIN) { + return -1; } - } while (written < len); + if (!wait_writable(&duetime)) { + return -1; + } + } return 0; } +int RdmaEndpoint::WriteToFd(void* data, size_t len) { + CHECK(data != NULL); + Socket* s = _socket; + const int fd = s->fd(); + return WriteToFdLoop(len, + [data, fd](size_t offset, size_t remaining) { + return write(fd, (uint8_t*)data + offset, remaining); + }, + [s, fd](const timespec* duetime) { + return s->WaitEpollOut(fd, true, duetime) == 0 || errno == ETIMEDOUT; + }); +} + +int RdmaEndpoint::WriteToFd(butil::IOBuf* data) { + CHECK(data != NULL); + Socket* s = _socket; + const int fd = s->fd(); + return WriteToFdLoop(data->size(), + [data, fd](size_t /*offset*/, size_t /*remaining*/) { + return data->cut_into_file_descriptor(fd); + }, + [s, fd](const timespec* duetime) { + return s->WaitEpollOut(fd, true, duetime) == 0 || errno == ETIMEDOUT; + }); +} + inline void RdmaEndpoint::TryReadOnTcp() { if (_socket->_nevent.fetch_add(1, butil::memory_order_acq_rel) == 0) { if (_state == FALLBACK_TCP) { @@ -424,19 +430,52 @@ inline void RdmaEndpoint::TryReadOnTcp() { } } +void 
RdmaEndpoint::ApplyRemoteHello(const ParsedHello& remote) { + _remote_recv_block_size = remote.block_size; + _local_window_capacity = + std::min(_sq_size, remote.rq_size) - RESERVED_WR_NUM; + _remote_window_capacity = + std::min(_rq_size, remote.sq_size) - RESERVED_WR_NUM; + _sq_imm_window_size = RESERVED_WR_NUM; + _remote_rq_window_size.store( + _local_window_capacity, butil::memory_order_relaxed); + _sq_window_size.store( + _local_window_capacity, butil::memory_order_relaxed); +} + +// Client-side handshake entry: the state machine. +// +// C_ALLOC_QPCQ +// | +// v +// C_HELLO_SEND (hs->SendLocalHello) +// | +// v +// C_HELLO_WAIT (hs->ReceiveAndParseRemoteHello) +// | +// v +// [negotiation: ApplyRemoteHello + C_BRINGUP_QP] +// | +// v +// C_ACK_SEND +// | +// v +// ESTABLISHED / FALLBACK_TCP void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { - RdmaEndpoint* ep = static_cast(arg); + auto ep = static_cast(arg); SocketUniquePtr s(ep->_socket); RdmaConnect::RunGuard rg((RdmaConnect*)s->_app_connect.get()); + auto rdma_transport = static_cast(s->_transport.get()); - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Start handshake on " << s->_local_side; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Start handshake on " << s->description(); - uint8_t data[g_rdma_hello_msg_len]; + std::unique_ptr handshake = CreateClientHandshake(ep); + CHECK(handshake != NULL); + ep->_handshake_version = handshake->ProtocolVersion(); - // First initialize CQ and QP resources + // First initialize CQ and QP resources. 
ep->_state = C_ALLOC_QPCQ; - auto* rdma_transport = static_cast(s->_transport.get()); if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; @@ -446,94 +485,40 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { // Send hello message to server ep->_state = C_HELLO_SEND; - HelloMessage local_msg; - local_msg.msg_len = g_rdma_hello_msg_len; - local_msg.hello_ver = g_rdma_hello_version; - local_msg.impl_ver = g_rdma_impl_version; - local_msg.block_size = g_rdma_recv_block_size; - local_msg.sq_size = ep->_sq_size; - local_msg.rq_size = ep->_rq_size; - local_msg.lid = GetRdmaLid(); - local_msg.gid = GetRdmaGid(); - if (BAIDU_LIKELY(ep->_resource)) { - local_msg.qp_num = ep->_resource->qp->qp_num; - } else { - // Only happens in UT - local_msg.qp_num = 0; - } - memcpy(data, MAGIC_STR, 4); - local_msg.Serialize((char*)data + 4); - if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send hello message to server:" << s->description(); + if (handshake->SendLocalHello() < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send hello message to server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - // Check magic str + // Receive and parse remote hello. 
ep->_state = C_HELLO_WAIT; - if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to get hello message from server:" << s->description(); + ParsedHello remote{}; + bool negotiated = false; + if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to receive hello from server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; - return NULL; - } - if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { - LOG(WARNING) << "Read unexpected data during handshake:" << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - // Read hello message from server - if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to get Hello Message from server:" << s->description(); - s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); - ep->_state = FAILED; - return NULL; - } - HelloMessage remote_msg; - remote_msg.Deserialize(data); - if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { - LOG(WARNING) << "Fail to parse Hello Message length from server:" - << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); - ep->_state = FAILED; - return NULL; - } - - if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { - // TODO: Read Hello Message customized data - // Just for future use, should not happen now - } - - if (!HelloNegotiationValid(remote_msg)) { + if (!negotiated) { LOG(WARNING) << "Fail to negotiate with server, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = 
RdmaTransport::RDMA_OFF; } else { - ep->_remote_recv_block_size = remote_msg.block_size; - ep->_local_window_capacity = - std::min(ep->_sq_size, remote_msg.rq_size) - RESERVED_WR_NUM; - ep->_remote_window_capacity = - std::min(ep->_rq_size, remote_msg.sq_size) - RESERVED_WR_NUM; - ep->_sq_imm_window_size = RESERVED_WR_NUM; - ep->_remote_rq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - ep->_sq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - + ep->ApplyRemoteHello(remote); ep->_state = C_BRINGUP_QP; - if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { - LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); + if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) { + LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" + << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; @@ -542,28 +527,26 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { // Send ACK message to server ep->_state = C_ACK_SEND; - uint32_t flags = 0; - if (rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF) { - flags |= ACK_MSG_RDMA_OK; - } - uint32_t* tmp = (uint32_t*)data; // avoid GCC warning on strict-aliasing - *tmp = butil::HostToNet32(flags); - if (ep->WriteToFd(data, ACK_MSG_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send Ack Message to server:" << s->description(); + uint32_t flags = rdma_transport->_rdma_state != RdmaTransport::RDMA_OFF ? 
HELLO_ACK_RDMA_OK : 0; + uint32_t flags_be = butil::HostToNet32(flags); + if (ep->WriteToFd(&flags_be, HELLO_ACK_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send Ack Message to server:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } if (rdma_transport->_rdma_state == RdmaTransport::RDMA_ON) { ep->_state = ESTABLISHED; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Client handshake ends (use rdma) on " << s->description(); + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Client handshake ends (use rdma v" << ep->_handshake_version + << ") on " << s->description(); } else { ep->_state = FALLBACK_TCP; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Client handshake ends (use tcp) on " << s->description(); } @@ -572,77 +555,75 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) { return NULL; } +// Server-side handshake entry: the state machine. 
+// +// S_HELLO_WAIT (read magic + dispatch + hs->ReceiveAndParseRemoteHello) +// | +// v +// [negotiation: ApplyRemoteHello + S_ALLOC_QPCQ + S_BRINGUP_QP] +// | +// v +// S_HELLO_SEND (hs->SendLocalHello) +// | +// v +// S_ACK_WAIT +// | +// v +// ESTABLISHED / FALLBACK_TCP void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { - RdmaEndpoint* ep = static_cast(arg); + auto ep = static_cast(arg); SocketUniquePtr s(ep->_socket); + auto rdma_transport = static_cast(s->_transport.get()); - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Start handshake on " << s->description(); - uint8_t data[g_rdma_hello_msg_len]; - ep->_state = S_HELLO_WAIT; - if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description() << " " << s->_remote_side; + uint8_t magic[MAGIC_STR_LEN]; + if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to read Hello Message from client:" + << s->description() << " " << s->_remote_side; s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - auto* rdma_transport = static_cast(s->_transport.get()); - if (memcmp(data, MAGIC_STR, MAGIC_STR_LEN) != 0) { - LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "It seems that the " - << "client does not use RDMA, fallback to TCP:" + + // Dispatch on magic, or fall back to TCP + std::unique_ptr handshake = CreateServerHandshakeByMagic(ep, magic); + if (!handshake) { + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "It seems that the client does not use RDMA, fallback to TCP:" << s->description(); - // we need to copy data read back to _socket->_read_buf - s->_read_buf.append(data, MAGIC_STR_LEN); + // We need to copy data read back to _socket->_read_buf. 
+ s->_read_buf.append(magic, MAGIC_STR_LEN); ep->_state = FALLBACK_TCP; rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->TryReadOnTcp(); return NULL; } + ep->_handshake_version = handshake->ProtocolVersion(); - if (ep->ReadFromFd(data, g_rdma_hello_msg_len - MAGIC_STR_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description(); + // Magic was already consumed above; the subclass MUST NOT re-read it. + ParsedHello remote{}; + bool negotiated = false; + if (handshake->ReceiveAndParseRemoteHello(&remote, &negotiated) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to receive hello from client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - HelloMessage remote_msg; - remote_msg.Deserialize(data); - if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) { - LOG(WARNING) << "Fail to parse Hello Message length from client:" - << s->description(); - s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); - ep->_state = FAILED; - return NULL; - } - if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { - // TODO: Read Hello Message customized header - // Just for future use, should not happen now - } - - if (!HelloNegotiationValid(remote_msg)) { + if (!negotiated) { LOG(WARNING) << "Fail to negotiate with client, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { - ep->_remote_recv_block_size = remote_msg.block_size; - ep->_local_window_capacity = - std::min(ep->_sq_size, remote_msg.rq_size) - RESERVED_WR_NUM; - ep->_remote_window_capacity = - std::min(ep->_rq_size, remote_msg.sq_size) - RESERVED_WR_NUM; - ep->_sq_imm_window_size = RESERVED_WR_NUM; - ep->_remote_rq_window_size.store( - ep->_local_window_capacity, 
butil::memory_order_relaxed); - ep->_sq_window_size.store( - ep->_local_window_capacity, butil::memory_order_relaxed); - + ep->ApplyRemoteHello(remote); ep->_state = S_ALLOC_QPCQ; if (ep->AllocateResources() < 0) { LOG(WARNING) << "Fail to allocate rdma resources, fallback to tcp:" @@ -650,7 +631,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; } else { ep->_state = S_BRINGUP_QP; - if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) { + if (ep->BringUpQp(remote.lid, remote.gid, remote.qp_num) < 0) { LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description(); rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; @@ -658,73 +639,55 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) { } } - // Send hello message to client ep->_state = S_HELLO_SEND; - HelloMessage local_msg; - local_msg.msg_len = g_rdma_hello_msg_len; - if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { - local_msg.impl_ver = 0; - local_msg.hello_ver = 0; - } else { - local_msg.lid = GetRdmaLid(); - local_msg.gid = GetRdmaGid(); - local_msg.block_size = g_rdma_recv_block_size; - local_msg.sq_size = ep->_sq_size; - local_msg.rq_size = ep->_rq_size; - local_msg.hello_ver = g_rdma_hello_version; - local_msg.impl_ver = g_rdma_impl_version; - if (BAIDU_LIKELY(ep->_resource)) { - local_msg.qp_num = ep->_resource->qp->qp_num; - } else { - // Only happens in UT - local_msg.qp_num = 0; - } - } - memcpy(data, MAGIC_STR, 4); - local_msg.Serialize((char*)data + 4); - if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to send Hello Message to client:" << s->description(); + if (handshake->SendLocalHello() < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to send Hello Message to client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), 
berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } - // Recv ACK Message ep->_state = S_ACK_WAIT; - if (ep->ReadFromFd(data, ACK_MSG_LEN) < 0) { - const int saved_errno = errno; - PLOG(WARNING) << "Fail to read ack message from client:" << s->description(); + uint32_t flags_be = 0; + if (ep->ReadFromFd(&flags_be, HELLO_ACK_LEN) < 0) { + int saved_errno = errno; + PLOG(WARNING) << "Fail to read ack message from client:" + << s->description(); s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(saved_errno)); + s->description().c_str(), berror(saved_errno)); ep->_state = FAILED; return NULL; } + uint32_t flags = butil::NetToHost32(flags_be); + bool client_ack_ok = (flags & HELLO_ACK_RDMA_OK) != 0; - // Check RDMA enable flag - uint32_t* tmp = (uint32_t*)data; // avoid GCC warning on strict-aliasing - uint32_t flags = butil::NetToHost32(*tmp); - if (flags & ACK_MSG_RDMA_OK) { + if (client_ack_ok) { if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { - LOG(WARNING) << "Fail to parse Hello Message length from client:" - << s->description(); + // Client asked for RDMA but we are falling back: protocol + // breakdown, abort the connection so the client sees a + // clean error rather than a half-up RDMA channel. 
+ LOG(WARNING) << "Client wants RDMA in ACK but server is in " + << "RDMA_OFF state: " << s->description(); s->SetFailed(EPROTO, "Fail to complete rdma handshake from %s: %s", - s->description().c_str(), berror(EPROTO)); + s->description().c_str(), berror(EPROTO)); ep->_state = FAILED; return NULL; - } else { - rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; - ep->_state = ESTABLISHED; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) - << "Server handshake ends (use rdma) on " << s->description(); } + rdma_transport->_rdma_state = RdmaTransport::RDMA_ON; + ep->_state = ESTABLISHED; + LOG_IF(INFO, FLAGS_rdma_trace_verbose) + << "Server handshake ends (use rdma v" << ep->_handshake_version + << ") on " << s->description(); } else { rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF; ep->_state = FALLBACK_TCP; - LOG_IF(INFO, FLAGS_rdma_trace_verbose) + LOG_IF(INFO, FLAGS_rdma_trace_verbose) << "Server handshake ends (use tcp) on " << s->description(); } + ep->TryReadOnTcp(); return NULL; @@ -1076,7 +1039,7 @@ int RdmaEndpoint::PostRecv(uint32_t num, bool zerocopy) { PLOG(WARNING) << "Fail to allocate rbuf"; return -1; } else { - CHECK(static_cast(size) == g_rdma_recv_block_size) << size; + CHECK_EQ(static_cast(size), g_rdma_recv_block_size); } } if (DoPostRecv(_rbuf_data[_rq_received], g_rdma_recv_block_size) < 0) { @@ -1645,6 +1608,7 @@ std::string RdmaEndpoint::GetStateStr() const { void RdmaEndpoint::DebugInfo(std::ostream& os, butil::StringPiece connector) const { os << "rdma_state=ON" << connector << "handshake_state=" << GetStateStr() + << connector << "handshake_version=" << static_cast(_handshake_version) << connector << "rdma_sq_imm_window_size=" << _sq_imm_window_size << connector << "rdma_remote_rq_window_size=" << _remote_rq_window_size.load(butil::memory_order_relaxed) << connector << "rdma_sq_window_size=" << _sq_window_size.load(butil::memory_order_relaxed) diff --git a/src/brpc/rdma/rdma_endpoint.h b/src/brpc/rdma/rdma_endpoint.h index 
54a008f1f7..7b6652bc86 100644 --- a/src/brpc/rdma/rdma_endpoint.h +++ b/src/brpc/rdma/rdma_endpoint.h @@ -40,6 +40,24 @@ DECLARE_bool(rdma_use_polling); DECLARE_int32(rdma_poller_num); DECLARE_bool(rdma_disable_bthread); +class RdmaHandshakeClientV2; +class RdmaHandshakeServerV2; +class RdmaHandshakeClientV3; +class RdmaHandshakeServerV3; +struct ParsedHello; +class RdmaHello; +class RdmaEndpoint; +namespace v2_wire { + int ReadBodyAndNegotiate(RdmaEndpoint* ep, ParsedHello* remote, bool* negotiated); + int DrainBytes(RdmaEndpoint* ep, size_t n); +} // namespace v2_wire + +namespace v3_wire { + void FillLocalRdmaHello(const RdmaEndpoint* ep, RdmaHello* msg); + int ReadAndParseV3Hello(RdmaEndpoint* ep, RdmaHello* out); + int WriteV3Hello(RdmaEndpoint* ep, const RdmaHello& msg); +} // namespace v3_wire + class RdmaConnect : public AppConnect { public: void StartConnect(const Socket* socket, @@ -74,6 +92,15 @@ struct RdmaResource { class BAIDU_CACHELINE_ALIGNMENT RdmaEndpoint : public SocketUser { friend class RdmaConnect; friend class Socket; +friend class RdmaHandshakeClientV2; +friend class RdmaHandshakeServerV2; +friend class RdmaHandshakeClientV3; +friend class RdmaHandshakeServerV3; +friend int v2_wire::ReadBodyAndNegotiate(RdmaEndpoint*, ParsedHello*, bool*); +friend int v2_wire::DrainBytes(RdmaEndpoint*, size_t); +friend void v3_wire::FillLocalRdmaHello(const RdmaEndpoint*, RdmaHello*); +friend int v3_wire::ReadAndParseV3Hello(RdmaEndpoint*, RdmaHello*); +friend int v3_wire::WriteV3Hello(RdmaEndpoint*, const RdmaHello&); public: explicit RdmaEndpoint(Socket* s); ~RdmaEndpoint() override; @@ -181,6 +208,7 @@ friend class Socket; // wait for _read_butex if encounter EAGAIN // return -1 if encounter other errno (including EOF) int ReadFromFd(void* data, size_t len); + int ReadFromFd(butil::IOPortal* data, size_t len); // Write at most len bytes from data to fd in _socket @@ -188,6 +216,17 @@ friend class Socket; // return -1 if encounter other errno int 
WriteToFd(void* data, size_t len); + // Write data to fd in _socket. + // wait for _epollout_butex if encounter EAGAIN. + // return -1 if encounter other errno. + int WriteToFd(butil::IOBuf* data); + + // Copy negotiated remote parameters into the endpoint and compute + // the SQ/RQ window capacities. Called by both + // ProcessHandshakeAtClient and ProcessHandshakeAtServer after the + // peer's hello has been validated. + void ApplyRemoteHello(const ParsedHello& remote); + // Bringup the QP from RESET state to RTS state // Arguments: // lid: remote LID @@ -225,6 +264,13 @@ friend class Socket; // State of Handshake State _state; + // Wire-level handshake protocol version (set by dispatch in + // ProcessHandshakeAtClient/Server). Aligned with the protocol code: + // 0 = unnegotiated + // 2 = v2 "RDMA" + // 3 = v3 "RDM3" + int _handshake_version; + // rdma resource RdmaResource* _resource; diff --git a/src/brpc/rdma/rdma_handshake.cpp b/src/brpc/rdma/rdma_handshake.cpp new file mode 100644 index 0000000000..9bd2312ec4 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.cpp @@ -0,0 +1,408 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#if BRPC_WITH_RDMA + +#include "brpc/rdma/rdma_handshake.h" + +#include +#include // std::min +#include +#include +#include +#include "butil/iobuf.h" // IOBuf, IOPortal, IOBufAsZeroCopy*Stream +#include "butil/sys_byteorder.h" +#include "brpc/socket.h" +#include "brpc/rdma/rdma_endpoint.h" +#include "brpc/rdma/rdma_helper.h" +#include "brpc/rdma_transport.h" +#include "brpc/rdma/rdma_handshake.pb.h" + +namespace brpc { +namespace rdma { + +DEFINE_int32(rdma_client_handshake_version, 2, + "RDMA handshake protocol version used by client. " + "2 = legacy 'RDMA' magic (default, compatible with all servers); " + "3 = new 'RDM3' protobuf-based handshake " + "(MUST only be enabled after target servers support v3)."); + +extern const uint16_t MIN_QP_SIZE; +extern const uint16_t MIN_BLOCK_SIZE; +extern uint32_t g_rdma_recv_block_size; +extern bool g_skip_rdma_init; + +// Wire-level constants for the v2 handshake. +static const char* MAGIC_STR = "RDMA"; +static constexpr uint16_t RDMA_HELLO_V2_MSG_LEN = 40; // In Byte +extern const uint16_t RDMA_HELLO_V2_VERSION = 2; +extern const uint16_t RDMA_IMPL_V2_VERSION = 1; + +// Wire-level constants for the v3 handshake. 
+static const char* MAGIC_STR_V3 = "RDM3"; +static const size_t RDMA_HELLO_V3_PB_SIZE_LEN = 4; +static const size_t RDMA_HELLO_V3_MAX_PB_SIZE = 4096; + +namespace v2_wire { + +void HelloMessage::Serialize(void* data) const { + uint16_t* current_pos = (uint16_t*)data; + *(current_pos++) = butil::HostToNet16(msg_len); + *(current_pos++) = butil::HostToNet16(hello_ver); + *(current_pos++) = butil::HostToNet16(impl_ver); + uint32_t* block_size_pos = (uint32_t*)current_pos; + *block_size_pos = butil::HostToNet32(block_size); + current_pos += 2; // move forward 4 Bytes + *(current_pos++) = butil::HostToNet16(sq_size); + *(current_pos++) = butil::HostToNet16(rq_size); + *(current_pos++) = butil::HostToNet16(lid); + fast_memcpy(current_pos, gid.raw, 16); + uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16); + *qp_num_pos = butil::HostToNet32(qp_num); +} + +void HelloMessage::Deserialize(void* data) { + uint16_t* current_pos = (uint16_t*)data; + msg_len = butil::NetToHost16(*current_pos++); + hello_ver = butil::NetToHost16(*current_pos++); + impl_ver = butil::NetToHost16(*current_pos++); + block_size = butil::NetToHost32(*(uint32_t*)current_pos); + current_pos += 2; // move forward 4 Bytes + sq_size = butil::NetToHost16(*current_pos++); + rq_size = butil::NetToHost16(*current_pos++); + lid = butil::NetToHost16(*current_pos++); + fast_memcpy(gid.raw, current_pos, 16); + qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16)); +} + +static bool ValidHelloMessage(const HelloMessage& msg) { + return msg.hello_ver == RDMA_HELLO_V2_VERSION && + msg.impl_ver == RDMA_IMPL_V2_VERSION && + msg.block_size >= MIN_BLOCK_SIZE && + msg.sq_size >= MIN_QP_SIZE && + msg.rq_size >= MIN_QP_SIZE; +} + +static void TranslateV2Hello(const HelloMessage& msg, ParsedHello* out) { + out->block_size = msg.block_size; + out->sq_size = msg.sq_size; + out->rq_size = msg.rq_size; + out->lid = msg.lid; + out->gid = msg.gid; + out->qp_num = msg.qp_num; +} + +int 
ReadBodyAndNegotiate(RdmaEndpoint* ep, ParsedHello* remote, bool* negotiated) { + uint8_t data[HELLO_MSG_LEN_MIN]; + if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) { + return -1; + } + HelloMessage remote_msg{}; + remote_msg.Deserialize(data); + if (remote_msg.msg_len < HELLO_MSG_LEN_MIN || + remote_msg.msg_len > HELLO_MSG_LEN_MAX) { + errno = EPROTO; + return -1; + } + if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) { + // Drain unknown trailing bytes so they don't pollute subsequent + // reads (e.g. the upcoming ACK message). v2 base fields already + // carry enough information for negotiation; unknown trailing + // bytes are treated as optional hints that v2 safely ignores. + size_t ext_len = remote_msg.msg_len - HELLO_MSG_LEN_MIN; + if (DrainBytes(ep, ext_len) < 0) { + return -1; + } + } + if (!ValidHelloMessage(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + TranslateV2Hello(remote_msg, remote); + return 0; +} + +int DrainBytes(RdmaEndpoint* ep, size_t n) { + uint8_t scratch[64]; + while (n > 0) { + size_t chunk = std::min(n, sizeof(scratch)); + if (ep->ReadFromFd(scratch, chunk) < 0) { + return -1; + } + n -= chunk; + } + return 0; +} + +} // namespace v2_wire + +int RdmaHandshakeClientV2::SendLocalHello() { + RdmaEndpoint* ep = _ep; + uint8_t data[RDMA_HELLO_V2_MSG_LEN]; + + v2_wire::HelloMessage local_msg{}; + local_msg.msg_len = RDMA_HELLO_V2_MSG_LEN; + local_msg.hello_ver = RDMA_HELLO_V2_VERSION; + local_msg.impl_ver = RDMA_IMPL_V2_VERSION; + local_msg.block_size = g_rdma_recv_block_size; + local_msg.sq_size = ep->_sq_size; + local_msg.rq_size = ep->_rq_size; + local_msg.lid = GetRdmaLid(); + local_msg.gid = GetRdmaGid(); + if (BAIDU_LIKELY(ep->_resource)) { + local_msg.qp_num = ep->_resource->qp->qp_num; + } else { + // Only happens in UT + local_msg.qp_num = 0; + } + fast_memcpy(data, MAGIC_STR, 4); + local_msg.Serialize((char*)data + 4); + return ep->WriteToFd(data, RDMA_HELLO_V2_MSG_LEN); +} + +int 
RdmaHandshakeClientV2::ReceiveAndParseRemoteHello(ParsedHello* remote, + bool* negotiated) { + RdmaEndpoint* ep = _ep; + + // Read and verify magic (the endpoint did NOT pre-read magic on the client side). + uint8_t magic[MAGIC_STR_LEN]; + if (ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + return -1; + } + if (memcmp(magic, MAGIC_STR, MAGIC_STR_LEN) != 0) { + errno = EPROTO; + return -1; + } + return v2_wire::ReadBodyAndNegotiate(ep, remote, negotiated); +} + +int RdmaHandshakeServerV2::ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) { + // Magic already consumed by ProcessHandshakeAtServer. + return v2_wire::ReadBodyAndNegotiate(_ep, remote, negotiated); +} + +int RdmaHandshakeServerV2::SendLocalHello() { + uint8_t data[RDMA_HELLO_V2_MSG_LEN]; + v2_wire::HelloMessage local_msg{}; + local_msg.msg_len = RDMA_HELLO_V2_MSG_LEN; + auto rdma_transport = static_cast(_ep->_socket->_transport.get()); + if (rdma_transport->_rdma_state == RdmaTransport::RDMA_OFF) { + local_msg.hello_ver = 0; + local_msg.impl_ver = 0; + local_msg.block_size = 0; + local_msg.sq_size = 0; + local_msg.rq_size = 0; + local_msg.lid = 0; + memset(local_msg.gid.raw, 0, sizeof(local_msg.gid.raw)); + local_msg.qp_num = 0; + } else { + local_msg.hello_ver = RDMA_HELLO_V2_VERSION; + local_msg.impl_ver = RDMA_IMPL_V2_VERSION; + local_msg.block_size = g_rdma_recv_block_size; + local_msg.sq_size = _ep->_sq_size; + local_msg.rq_size = _ep->_rq_size; + local_msg.lid = GetRdmaLid(); + local_msg.gid = GetRdmaGid(); + if (BAIDU_LIKELY(_ep->_resource)) { + local_msg.qp_num = _ep->_resource->qp->qp_num; + } else { + // Only happens in UT + local_msg.qp_num = 0; + } + } + fast_memcpy(data, MAGIC_STR, 4); + local_msg.Serialize((char*)data + 4); + return _ep->WriteToFd(data, RDMA_HELLO_V2_MSG_LEN); +} + +namespace v3_wire { + +bool ValidRdmaHello(const RdmaHello& msg) { + if (msg.gid().size() != sizeof(ibv_gid)) { + return false; + } + // ParsedHello stores these as uint16_t; reject values that 
would truncate. + constexpr uint16_t MAX_UINT16 = std::numeric_limits::max(); + if (msg.sq_size() > MAX_UINT16 || msg.rq_size() > MAX_UINT16 || msg.lid() > MAX_UINT16) { + return false; + } + if (msg.block_size() < MIN_BLOCK_SIZE) { + return false; + } + if (msg.sq_size() < MIN_QP_SIZE) { + return false; + } + if (msg.rq_size() < MIN_QP_SIZE) { + return false; + } + // qp_num == 0 only happens in UT (no real QP allocated). + if (msg.qp_num() == 0 && !g_skip_rdma_init) { + return false; + } + return true; +} + +void FillLocalRdmaHello(const RdmaEndpoint* ep, RdmaHello* msg) { + msg->set_block_size(g_rdma_recv_block_size); + msg->set_sq_size(ep->_sq_size); + msg->set_rq_size(ep->_rq_size); + msg->set_lid(GetRdmaLid()); + ibv_gid gid = GetRdmaGid(); + msg->set_gid(std::string(reinterpret_cast(gid.raw), + sizeof(gid.raw))); + if (BAIDU_LIKELY(ep->_resource)) { + msg->set_qp_num(ep->_resource->qp->qp_num); + } else { + // Only happens in UT + msg->set_qp_num(0); + } +} + +int ReadAndParseV3Hello(RdmaEndpoint* ep, RdmaHello* out) { + uint8_t size_buf[RDMA_HELLO_V3_PB_SIZE_LEN]; + if (ep->ReadFromFd(size_buf, RDMA_HELLO_V3_PB_SIZE_LEN) < 0) { + return -1; + } + uint32_t pb_size = butil::NetToHost32( + *reinterpret_cast(size_buf)); + if (pb_size == 0 || pb_size > RDMA_HELLO_V3_MAX_PB_SIZE) { + errno = EPROTO; + return -1; + } + butil::IOPortal body; + if (ep->ReadFromFd(&body, pb_size) < 0) { + return -1; + } + + butil::IOBufAsZeroCopyInputStream input(body); + if (!out->ParseFromZeroCopyStream(&input)) { + LOG(ERROR) << "Failed to parse RdmaHello"; + errno = EPROTO; + return -1; + } + return 0; +} + +int WriteV3Hello(RdmaEndpoint* ep, const RdmaHello& msg) { + uint32_t pb_size = static_cast(msg.ByteSizeLong()); + if (pb_size > RDMA_HELLO_V3_MAX_PB_SIZE) { + errno = EPROTO; + return -1; + } + + // [ "RDM3" 4B ][ pb_size 4B (big-endian) ][ RdmaHello protobuf bytes ] + butil::IOBuf packet; + packet.append(MAGIC_STR_V3, MAGIC_STR_LEN); + uint32_t pb_size_be = 
butil::HostToNet32(pb_size); + packet.append(&pb_size_be, RDMA_HELLO_V3_PB_SIZE_LEN); + butil::IOBufAsZeroCopyOutputStream output(&packet); + if (!msg.SerializeToZeroCopyStream(&output)) { + LOG(ERROR) << "Failed to serialize RdmaHello"; + errno = EPROTO; + return -1; + } + return ep->WriteToFd(&packet); +} + +void TranslateHello(const RdmaHello& msg, ParsedHello* out) { + out->block_size = msg.block_size(); + out->sq_size = static_cast(msg.sq_size()); + out->rq_size = static_cast(msg.rq_size()); + out->lid = static_cast(msg.lid()); + fast_memcpy(out->gid.raw, msg.gid().data(), sizeof(out->gid.raw)); + out->qp_num = msg.qp_num(); +} + +} // namespace v3_wire + +int RdmaHandshakeClientV3::SendLocalHello() { + RdmaHello local_msg{}; + v3_wire::FillLocalRdmaHello(_ep, &local_msg); + return v3_wire::WriteV3Hello(_ep, local_msg); +} + +int RdmaHandshakeClientV3::ReceiveAndParseRemoteHello(ParsedHello* remote, + bool* negotiated) { + uint8_t magic[MAGIC_STR_LEN]; + if (_ep->ReadFromFd(magic, MAGIC_STR_LEN) < 0) { + return -1; + } + if (memcmp(magic, MAGIC_STR_V3, MAGIC_STR_LEN) != 0) { + errno = EPROTO; + return -1; + } + + RdmaHello remote_msg{}; + if (v3_wire::ReadAndParseV3Hello(_ep, &remote_msg) < 0) { + return -1; + } + if (!v3_wire::ValidRdmaHello(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + v3_wire::TranslateHello(remote_msg, remote); + return 0; +} + +int RdmaHandshakeServerV3::ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) { + // Magic already consumed by ProcessHandshakeAtServer. 
+ RdmaHello remote_msg{}; + if (v3_wire::ReadAndParseV3Hello(_ep, &remote_msg) < 0) { + return -1; + } + if (!v3_wire::ValidRdmaHello(remote_msg)) { + *negotiated = false; + return 0; + } + *negotiated = true; + v3_wire::TranslateHello(remote_msg, remote); + return 0; +} + +int RdmaHandshakeServerV3::SendLocalHello() { + RdmaHello local_msg{}; + v3_wire::FillLocalRdmaHello(_ep, &local_msg); + return v3_wire::WriteV3Hello(_ep, local_msg); +} + +std::unique_ptr CreateClientHandshake(RdmaEndpoint* ep) { + switch (FLAGS_rdma_client_handshake_version) { + case 3: + return std::unique_ptr(new RdmaHandshakeClientV3(ep)); + case 2: + default: + return std::unique_ptr(new RdmaHandshakeClientV2(ep)); + } +} + +std::unique_ptr CreateServerHandshakeByMagic( + RdmaEndpoint* ep, const uint8_t magic[MAGIC_STR_LEN]) { + if (memcmp(magic, MAGIC_STR, MAGIC_STR_LEN) == 0) { + return std::unique_ptr(new RdmaHandshakeServerV2(ep)); + } + if (memcmp(magic, MAGIC_STR_V3, MAGIC_STR_LEN) == 0) { + return std::unique_ptr(new RdmaHandshakeServerV3(ep)); + } + return nullptr; +} + +} // namespace rdma +} // namespace brpc + +#endif // BRPC_WITH_RDMA diff --git a/src/brpc/rdma/rdma_handshake.h b/src/brpc/rdma/rdma_handshake.h new file mode 100644 index 0000000000..5f36a9e6e2 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.h @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_RDMA_HANDSHAKE_H +#define BRPC_RDMA_HANDSHAKE_H + +#if BRPC_WITH_RDMA + +#include +#include +#include +#include +#include "butil/macros.h" + +namespace brpc { +namespace rdma { + +class RdmaEndpoint; + +// Length of the RDMA handshake magic string (e.g. "RDMA", "RDM3"). +static const size_t MAGIC_STR_LEN = 4; + +// Wire-format-agnostic representation of a peer's hello message. +// Each protocol version (v2 binary, v3 protobuf) translates its own +// wire format into this struct so the state-machine driver in +// RdmaEndpoint::ProcessHandshakeAt{Client,Server} stays free of any +// wire-format details. +struct ParsedHello { + uint32_t block_size; + uint16_t sq_size; + uint16_t rq_size; + uint16_t lid; + ibv_gid gid; + uint32_t qp_num; +}; + +namespace v2_wire { + +// Wire constants for the v2 hello. +// +// HELLO_MSG_LEN_MIN: total length of the base v2 hello (4B magic + +// 36B HelloMessage). Anything shorter than this is malformed. +// HELLO_MSG_LEN_MAX: upper bound for the entire v2 hello message +// length declared by HelloMessage::msg_len. Anything beyond this is +// treated as a protocol error and the connection is closed without +// attempting to drain. +static constexpr size_t HELLO_MSG_LEN_MIN = 40; +static constexpr size_t HELLO_MSG_LEN_MAX = 4096; + +// v2 binary HelloMessage. 
+struct HelloMessage { + void Serialize(void* data) const; + void Deserialize(void* data); + + uint16_t msg_len; + uint16_t hello_ver; + uint16_t impl_ver; + uint32_t block_size; + uint16_t sq_size; + uint16_t rq_size; + uint16_t lid; + ibv_gid gid; + uint32_t qp_num; +}; + +} // namespace v2_wire + +// Abstract base class of an RDMA handshake. +// +// Acts as the protocol-version dispatch point for the state machine +// driven by RdmaEndpoint::ProcessHandshakeAt{Client,Server}. +class RdmaHandshake { +public: + explicit RdmaHandshake(RdmaEndpoint* ep) : _ep(ep) {} + virtual ~RdmaHandshake() = default; + + DISALLOW_COPY_AND_ASSIGN(RdmaHandshake); + + // Wire-level protocol version (2 for "RDMA", 3 for "RDM3"). + virtual int ProtocolVersion() const = 0; + + // Build and send the local hello (including the protocol magic). + // Returns 0 on success, -1 on IO error (errno set). + // + // For a server in fallback state, implementations MUST still + // produce a sendable message; each version uses its own wire + // convention to signal "I am falling back" to the peer: + // - v2: zero hello_ver/impl_ver so the peer's ValidHelloMessage + // rejects it; + // - v3: qp_num==0 so the peer's ValidRdmaHello rejects it. + virtual int SendLocalHello() = 0; + + // Read the peer's hello, validate it, and translate into ParsedHello. + // + // Role-specific semantics: + // - Client subclasses: read & verify the 4B magic first, then the + // body. (The endpoint did NOT pre-read the magic on the client + // side.) + // - Server subclasses: read ONLY the body. The 4B magic was + // already consumed by ProcessHandshakeAtServer and was used to + // pick `this` from CreateServerHandshakeByMagic; re-reading + // would deadlock. + // + // Outputs: + // *negotiated -- true if the remote hello is structurally valid + // AND passes per-protocol negotiation checks; + // false means the peer asked for fallback or sent + // something we can't honor.
+ // Returns: + // 0 -- IO/parsing layer OK; check *negotiated and *remote. + // -1 -- IO error or unrecoverable protocol error (errno set). + virtual int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) = 0; + +protected: + RdmaEndpoint* _ep; +}; + +// v2 handshake (legacy "RDMA" magic, 36B binary HelloMessage). +class RdmaHandshakeClientV2 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 2; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +class RdmaHandshakeServerV2 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 2; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +// v3 handshake (new "RDM3" magic, protobuf RdmaHello). +// [ "RDM3" 4B ][ pb_size 4B (big-endian) ][ RdmaHello protobuf bytes ] +class RdmaHandshakeClientV3 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 3; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +class RdmaHandshakeServerV3 : public RdmaHandshake { +public: + using RdmaHandshake::RdmaHandshake; + int ProtocolVersion() const override { return 3; } + + int SendLocalHello() override; + int ReceiveAndParseRemoteHello(ParsedHello* remote, bool* negotiated) override; +}; + +// Factory methods +// +// Pick the client-side handshake based on +// FLAGS_rdma_client_handshake_version: +// 2 (default) -> RdmaHandshakeClientV2 +// 3 -> RdmaHandshakeClientV3 +// Other values fall back to V2. +std::unique_ptr CreateClientHandshake(RdmaEndpoint* ep); + +// Pick the server-side handshake based on the 4B magic already read. 
+// Returns NULL if `magic` is not a recognized RDMA magic +// (the caller should then fallback to TCP). +// "RDMA" -> RdmaHandshakeServerV2 +// "RDM3" -> RdmaHandshakeServerV3 +std::unique_ptr CreateServerHandshakeByMagic( + RdmaEndpoint* ep, const uint8_t magic[MAGIC_STR_LEN]); + +} // namespace rdma +} // namespace brpc + +#endif // BRPC_WITH_RDMA +#endif // BRPC_RDMA_HANDSHAKE_H diff --git a/src/brpc/rdma/rdma_handshake.proto b/src/brpc/rdma/rdma_handshake.proto new file mode 100644 index 0000000000..c180b58b96 --- /dev/null +++ b/src/brpc/rdma/rdma_handshake.proto @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +syntax = "proto2"; + +package brpc.rdma; + +option cc_generic_services = false; + +// RDMA handshake v3 message. +// Carried in the body of every "RDM3" handshake packet: +// +// [ "RDM3" 4B ][ pb_size 4B ][ RdmaHello protobuf bytes ] +message RdmaHello { + // ---- v2-parity base fields (required) ---- + // Listed first and in the same logical order as the v2 binary + // HelloMessage (minus hello_ver / impl_ver, which are subsumed by + // the wrapper magic "RDM3"). Keeping the same ordering simplifies + // side-by-side reasoning when debugging mixed v2/v3 traffic. 
+ // + // Marked `required` because the handshake cannot proceed without + // any of these; ParseFromZeroCopyStream() will reject a missing field at + // the protobuf layer, so we don't need an extra has_xxx() check + // in ValidRdmaHello() for presence. + required uint32 block_size = 1; + required uint32 sq_size = 2; + required uint32 rq_size = 3; + required uint32 lid = 4; + // Must be exactly 16 bytes (sizeof(ibv_gid)). + required bytes gid = 5; + required uint32 qp_num = 6; +} diff --git a/src/brpc/rdma_transport.h b/src/brpc/rdma_transport.h index 65ae88f7a6..d8520b1a6d 100644 --- a/src/brpc/rdma_transport.h +++ b/src/brpc/rdma_transport.h @@ -25,9 +25,10 @@ namespace brpc { class RdmaTransport : public Transport { - friend class TransportFactory; - friend class rdma::RdmaEndpoint; - friend class rdma::RdmaConnect; +friend class TransportFactory; +friend class rdma::RdmaEndpoint; +friend class rdma::RdmaConnect; +friend class rdma::RdmaHandshakeServerV2; public: void Init(Socket* socket, const SocketOptions& options) override; void Release() override; @@ -47,7 +48,7 @@ class RdmaTransport : public Transport { private: static bool OptionsAvailableForRdma(const ChannelOptions* opt); static bool OptionsAvailableOverRdma(const ServerOptions* opt); -private: + // The on/off state of RDMA enum RdmaState { RDMA_ON, diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp index 935e5f1bb1..57f665da91 100644 --- a/src/brpc/server.cpp +++ b/src/brpc/server.cpp @@ -855,26 +855,37 @@ int Server::StartInternal(const butil::EndPoint& endpoint, return -1; } - copy_and_fill_server_options(_options, opt ? *opt : ServerOptions()); - - if (!_options.h2_settings.IsValid(true/*log_error*/)) { + // Validate the user-provided ServerOptions BEFORE + // copy_and_fill_server_options below. This is important: + // copy_and_fill_server_options unconditionally transfers ownership of + // user-provided pointers (nshead_service, thrift_service, ...) into + // _options.
If we instead validated against _options after the copy, + // a failed Start() would leave fake/invalid pointers behind in + // _options, and the NEXT Start() would attempt to `delete` them via + // FREE_PTR_IF_NOT_REUSED, crashing (see RdmaTest.server_option_invalid). + const ServerOptions default_opt; + const ServerOptions& real_opt = opt ? *opt : default_opt; + + if (!real_opt.h2_settings.IsValid(true/*log_error*/)) { LOG(ERROR) << "Invalid h2_settings"; return -1; } - if (_options.bthread_tag < BTHREAD_TAG_DEFAULT || - _options.bthread_tag >= FLAGS_task_group_ntags) { - LOG(ERROR) << "Fail to set tag " << _options.bthread_tag + if (real_opt.bthread_tag < BTHREAD_TAG_DEFAULT || + real_opt.bthread_tag >= FLAGS_task_group_ntags) { + LOG(ERROR) << "Fail to set tag " << real_opt.bthread_tag << ", tag range is [" << BTHREAD_TAG_DEFAULT << ":" << FLAGS_task_group_ntags << ")"; return -1; } - int ret = TransportFactory::ContextInitOrDie(_options.socket_mode, true, &_options); + int ret = TransportFactory::ContextInitOrDie(real_opt.socket_mode, true, &real_opt); if (ret != 0) { LOG(ERROR) << "Fail to initialize transport context for server, ret=" << ret; return -1; } + copy_and_fill_server_options(_options, real_opt); + if (_options.http_master_service) { // Check requirements for http_master_service: // has "default_method" & request/response have no fields diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 0ca6950428..a3d43fa3b8 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -1554,8 +1554,7 @@ void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) { g_vars->channel_conn << 1; } if (s->_app_connect) { - s->_app_connect->StartConnect(req->get_socket(), - AfterAppConnected, req); + s->_app_connect->StartConnect(req->get_socket(), AfterAppConnected, req); } else { // Successfully created a connection AfterAppConnected(0, req); diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 816fccdf27..7311d73895 100644 --- 
a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -56,6 +56,10 @@ class ChannelBalancer; namespace rdma { class RdmaEndpoint; class RdmaConnect; +class RdmaHandshakeClientV2; +class RdmaHandshakeServerV2; +class RdmaHandshakeClientV3; +class RdmaHandshakeServerV3; } class Socket; @@ -317,6 +321,10 @@ friend class policy::RtmpContext; friend class schan::ChannelBalancer; friend class rdma::RdmaEndpoint; friend class rdma::RdmaConnect; +friend class rdma::RdmaHandshakeClientV2; +friend class rdma::RdmaHandshakeServerV2; +friend class rdma::RdmaHandshakeClientV3; +friend class rdma::RdmaHandshakeServerV3; friend class HealthCheckTask; friend class OnAppHealthCheckDone; friend class HealthCheckManager; diff --git a/src/butil/thread_key.h b/src/butil/thread_key.h index c150528b63..77f346d608 100644 --- a/src/butil/thread_key.h +++ b/src/butil/thread_key.h @@ -18,6 +18,7 @@ #ifndef BUTIL_THREAD_KEY_H #define BUTIL_THREAD_KEY_H +#include #include #include #include diff --git a/test/brpc_rdma_unittest.cpp b/test/brpc_rdma_unittest.cpp index ccb280f1c8..43c6edfd12 100644 --- a/test/brpc_rdma_unittest.cpp +++ b/test/brpc_rdma_unittest.cpp @@ -24,7 +24,6 @@ #include #include "butil/endpoint.h" #include "butil/fd_guard.h" -#include "butil/fd_utility.h" #include "butil/iobuf.h" #include "butil/sys_byteorder.h" #include "butil/files/temp_file.h" @@ -36,15 +35,15 @@ #include "brpc/errno.pb.h" #include "brpc/parallel_channel.h" #include "brpc/selective_channel.h" +#include "brpc/rdma_transport.h" #include "brpc/rdma/block_pool.h" #include "brpc/rdma/rdma_endpoint.h" +#include "brpc/rdma/rdma_handshake.h" +#include "brpc/rdma/rdma_handshake.pb.h" #include "brpc/rdma/rdma_helper.h" #include "echo.pb.h" static const int PORT = 8713; -static const size_t RDMA_HELLO_MSG_LEN = 40; -static uint16_t RDMA_HELLO_VERSION = 2; -static uint16_t RDMA_IMPL_VERSION = 1; using namespace brpc; @@ -56,23 +55,13 @@ DEFINE_bool(rdma_test_enable, false, "Enable tests requring rdma runtime."); namespace 
rdma { -struct HelloMessage { - void Serialize(void* data) const; - void Deserialize(void* data); - - uint16_t msg_len; - uint16_t hello_ver; - uint16_t impl_ver; - uint32_t block_size; - uint16_t sq_size; - uint16_t rq_size; - uint16_t lid; - ibv_gid gid; - uint32_t qp_num; -}; +extern const uint16_t RDMA_HELLO_V2_VERSION; +extern const uint16_t RDMA_IMPL_V2_VERSION; DECLARE_bool(rdma_trace_verbose); DECLARE_int32(rdma_memory_pool_max_regions); +DECLARE_int32(rdma_client_handshake_version); + extern ibv_cq* (*IbvCreateCq)(ibv_context*, int, void*, ibv_comp_channel*, int); extern int (*IbvDestroyCq)(ibv_cq*); extern ibv_qp* (*IbvCreateQp)(ibv_pd*, ibv_qp_init_attr*); @@ -81,8 +70,8 @@ extern int (*IbvQueryQp)(ibv_qp*, ibv_qp_attr*, ibv_qp_attr_mask, ibv_qp_init_at extern int (*IbvDestroyQp)(ibv_qp*); extern butil::atomic g_rdma_available; extern bool g_skip_rdma_init; -} -} +} // namespace rdma +} // namespace brpc static std::string g_ip = "127.0.0.1"; static butil::EndPoint g_ep; @@ -109,7 +98,7 @@ class MyEchoService : public ::test::EchoService { LOG(INFO) << "sleep " << req->sleep_us() << "us..."; bthread_usleep(req->sleep_us()); } - res->set_message(req->message()); + res->set_message("MyEchoService"); if (req->code() != 0) { res->add_code_list(req->code()); } @@ -136,11 +125,12 @@ class RdmaTest : public ::testing::Test { rdma::DumpMemoryPoolInfo(std::cout); } -private: +protected: void StartServer(bool use_rdma = true) { ServerOptions options; - options.use_rdma = use_rdma; - options.idle_timeout_sec = 10; + options.enabled_protocols = "baidu_std"; + options.socket_mode = use_rdma ? SOCKET_MODE_RDMA : SOCKET_MODE_TCP; + options.idle_timeout_sec = 5; options.max_concurrency = 0; options.internal_port = -1; EXPECT_EQ(0, _server.Start(PORT, &options)); @@ -171,6 +161,29 @@ class RdmaTest : public ::testing::Test { MyEchoService _svc; }; +// Parameterized fixture used by upper-layer RPC tests that have no +// dependency on the handshake wire format. 
The parameter is the +// client-side handshake protocol version (FLAGS_rdma_client_handshake_version), +// so every TEST_P below is automatically executed once per supported +// version. Add a new version to INSTANTIATE_TEST_SUITE_P at the bottom +// of this file and these RPC tests will gain coverage for free. +class RdmaRpcTest : public RdmaTest, + public ::testing::WithParamInterface { +protected: + void SetUp() override { + RdmaTest::SetUp(); + _saved_handshake_version = rdma::FLAGS_rdma_client_handshake_version; + rdma::FLAGS_rdma_client_handshake_version = GetParam(); + } + void TearDown() override { + rdma::FLAGS_rdma_client_handshake_version = _saved_handshake_version; + RdmaTest::TearDown(); + } + +private: + int _saved_handshake_version = 2; +}; + TEST_F(RdmaTest, client_close_before_hello_send) { StartServer(); @@ -184,7 +197,7 @@ TEST_F(RdmaTest, client_close_before_hello_send) { ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg Socket* s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -205,13 +218,13 @@ TEST_F(RdmaTest, client_hello_msg_invalid_magic_str) { ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg Socket* s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); - uint8_t data[RDMA_HELLO_MSG_LEN]; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; memcpy(data, "PRPC", 4); // send as normal baidu_std protocol ASSERT_EQ(4, write(sockfd, data, 4)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, 
s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); StopServer(); } @@ -231,11 +244,11 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RD", 2); ASSERT_EQ(2, write(sockfd1, data, 2)); // break in magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -245,11 +258,11 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // break after magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd2); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -259,12 +272,12 @@ TEST_F(RdmaTest, client_close_during_hello_send) { ASSERT_EQ(0, connect(sockfd3, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, 
s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); memset(data + 4, 0, 4); ASSERT_EQ(8, write(sockfd3, data, 8)); // break after magic str usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd3); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -280,18 +293,18 @@ TEST_F(RdmaTest, client_hello_msg_invalid_len) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - uint8_t data[RDMA_HELLO_MSG_LEN]; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; butil::fd_guard sockfd1(socket(AF_INET, SOCK_STREAM, 0)); ASSERT_TRUE(sockfd1 >= 0); ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memset(data + 4, 0, 36); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); // Write invalid length. 
usleep(100000); // wait for server to handle the msg @@ -302,11 +315,11 @@ TEST_F(RdmaTest, client_hello_msg_invalid_len) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint16_t len = butil::HostToNet16(35); memcpy(data + 4, &len, sizeof(len)); memset(data + 6, 0, 34); @@ -325,8 +338,8 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - uint8_t data[RDMA_HELLO_MSG_LEN]; - uint16_t len = butil::HostToNet16(RDMA_HELLO_MSG_LEN); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + uint16_t len = butil::HostToNet16(rdma::v2_wire::HELLO_MSG_LEN_MIN); uint16_t ver = butil::HostToNet16(1); butil::fd_guard sockfd1(socket(AF_INET, SOCK_STREAM, 0)); @@ -334,22 +347,29 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data + 4, &len, 2); memset(data + 6, 0, 34); memcpy(data + 6, &ver, 2); // hello_ver == 1, impl_ver == 0 - ASSERT_EQ(36, write(sockfd1, data, 36)); + // Write the 36B base starting at data + 4 (NOT data). Pre-Step-1 this + // UT mistakenly wrote `data, 36` which included the leftover "RDMA" + // magic at data[0..4); the server parsed it as msg_len = 0x5244 and + // happened to fall through to NegotiationValid (which then failed on + // hello_ver). Now that Step 1 enforces a HELLO_MSG_LEN_MAX upper bound, + // such an oversized msg_len would be rejected before reaching the + // version check, breaking the intent of this UT. + ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); uint32_t flags = 0; ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd1.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -359,21 +379,23 @@ TEST_F(RdmaTest, client_hello_msg_invalid_version) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, 
static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data, "RDMA", 4); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); memcpy(data + 8, &ver, 2); // hello_ver == 0, impl_ver == 1 - ASSERT_EQ(36, write(sockfd2, data, 36)); + // See comment above on `write(sockfd1, data + 4, 36)` for why we + // write from data + 4 instead of data. + ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd2.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -390,11 +412,11 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { addr.sin_port = htons(PORT); Socket* s = NULL; uint32_t flags = butil::HostToNet32(0); - rdma::HelloMessage msg{}; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; 
msg.sq_size = 10; msg.rq_size = 16; @@ -406,17 +428,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd1.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -431,17 +453,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd2.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -456,17 +478,17 @@ TEST_F(RdmaTest, client_hello_msg_invalid_sq_rq_block_size) { ASSERT_EQ(0, connect(sockfd3, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd3, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd3, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); ASSERT_EQ(sizeof(flags), write(sockfd3, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); sockfd3.reset(-1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -482,11 +504,11 @@ TEST_F(RdmaTest, client_close_after_qp_build) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -500,10 +522,10 @@ TEST_F(RdmaTest, client_close_after_qp_build) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); 
ASSERT_EQ(40, write(sockfd1, data, 40)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -519,11 +541,11 @@ TEST_F(RdmaTest, client_close_during_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -537,17 +559,17 @@ TEST_F(RdmaTest, client_close_during_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -563,11 +585,11 @@ TEST_F(RdmaTest, client_close_after_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -581,18 +603,18 @@ TEST_F(RdmaTest, client_close_after_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(0); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); - ASSERT_EQ(Socket::RDMA_OFF, s->_rdma_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); close(sockfd1); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -602,17 +624,17 @@ TEST_F(RdmaTest, client_close_after_ack_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); close(sockfd2); usleep(100000); // wait for server to handle the msg ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -628,11 +650,11 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { addr.sin_family = AF_INET; addr.sin_port = htons(PORT); Socket* s = NULL; - rdma::HelloMessage msg; - uint8_t data[RDMA_HELLO_MSG_LEN]; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -646,17 +668,17 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { ASSERT_EQ(0, connect(sockfd1, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd1, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd1, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); uint32_t flags = butil::HostToNet32(0); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(sizeof(flags), write(sockfd1, &flags, sizeof(flags))); usleep(100000); ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -666,17 +688,17 @@ TEST_F(RdmaTest, client_send_data_on_tcp_after_ack_send) { ASSERT_EQ(0, connect(sockfd2, (sockaddr*)&addr, sizeof(sockaddr))); usleep(100000); // wait for server to handle the msg s = GetSocketFromServer(0); - ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, write(sockfd2, data, 4)); // Write magic string. 
usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(36, write(sockfd2, data + 4, 36)); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); flags = butil::HostToNet32(1); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); // wait for server to handle the msg - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(sizeof(flags), write(sockfd2, &flags, sizeof(flags))); usleep(100000); ASSERT_EQ(NULL, GetSocketFromServer(0)); @@ -690,7 +712,7 @@ TEST_F(RdmaTest, server_miss_before_hello_send) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -706,7 +728,7 @@ TEST_F(RdmaTest, server_miss_before_hello_send) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); @@ -721,7 +743,7 @@ TEST_F(RdmaTest, server_close_before_hello_send) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -737,15 +759,15 @@ TEST_F(RdmaTest, server_close_before_hello_send) { 
usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -757,7 +779,7 @@ TEST_F(RdmaTest, server_miss_during_magic_str) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -773,12 +795,12 @@ TEST_F(RdmaTest, server_miss_during_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(2, write(acc_fd, "RD", 2)); usleep(100000); bthread_id_join(cntl.call_id()); @@ -792,7 +814,7 @@ TEST_F(RdmaTest, server_close_during_magic_str) { Channel channel; 
ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -808,17 +830,17 @@ TEST_F(RdmaTest, server_close_during_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(2, write(acc_fd, "RD", 2)); usleep(100000); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -830,7 +852,7 @@ TEST_F(RdmaTest, server_hello_invalid_magic_str) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -846,15 +868,15 @@ TEST_F(RdmaTest, server_hello_invalid_magic_str) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); 
+ uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "ABCD", 4)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); @@ -866,7 +888,7 @@ TEST_F(RdmaTest, server_miss_during_hello_msg) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -882,12 +904,12 @@ TEST_F(RdmaTest, server_miss_during_hello_msg) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "RDMA", 4)); ASSERT_EQ(2, write(acc_fd, "00", 2)); bthread_id_join(cntl.call_id()); @@ -901,7 +923,7 @@ TEST_F(RdmaTest, server_close_during_hello_msg) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -917,17 +939,17 @@ TEST_F(RdmaTest, server_close_during_hello_msg) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - 
ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); ASSERT_EQ(4, write(acc_fd, "RDMA", 4)); ASSERT_EQ(2, write(acc_fd, "00", 2)); close(acc_fd); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EEOF, cntl.ErrorCode()); @@ -939,7 +961,7 @@ TEST_F(RdmaTest, server_hello_invalid_msg_len) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -955,19 +977,19 @@ TEST_F(RdmaTest, server_hello_invalid_msg_len) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); memcpy(data, "RDMA", 4); uint16_t len = butil::HostToNet16(35); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + 
ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FAILED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FAILED, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); @@ -979,7 +1001,7 @@ TEST_F(RdmaTest, server_hello_invalid_version) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -995,19 +1017,19 @@ TEST_F(RdmaTest, server_hello_invalid_version) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); memcpy(data, "RDMA", 4); - uint16_t len = butil::HostToNet16(RDMA_HELLO_MSG_LEN); + uint16_t len = butil::HostToNet16(rdma::v2_wire::HELLO_MSG_LEN_MIN); memcpy(data + 4, &len, 2); memset(data + 6, 0, 32); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(0, butil::NetToHost32(*tmp)); @@ -1022,7 +1044,7 
@@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1038,15 +1060,15 @@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; msg.hello_ver = 1; msg.impl_ver = 1; msg.sq_size = 0; @@ -1056,10 +1078,10 @@ TEST_F(RdmaTest, server_hello_invalid_sq_rq_size) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(0, butil::NetToHost32(*tmp)); @@ -1074,7 +1096,7 @@ TEST_F(RdmaTest, server_miss_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; 
chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1090,17 +1112,17 @@ TEST_F(RdmaTest, server_miss_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1108,10 +1130,10 @@ TEST_F(RdmaTest, server_miss_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(1, butil::NetToHost32(*tmp)); @@ -1126,7 +1148,7 @@ TEST_F(RdmaTest, server_close_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; 
chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1142,17 +1164,17 @@ TEST_F(RdmaTest, server_close_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1160,10 +1182,10 @@ TEST_F(RdmaTest, server_close_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); ASSERT_EQ(4, read(acc_fd, data, 4)); uint32_t* tmp = (uint32_t*)data; ASSERT_EQ(1, butil::NetToHost32(*tmp)); @@ -1179,7 +1201,7 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; 
chan_options.max_retry = 0; @@ -1195,17 +1217,17 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::C_HELLO_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); ASSERT_TRUE(acc_fd >= 0); - uint8_t data[RDMA_HELLO_MSG_LEN]; - ASSERT_EQ(RDMA_HELLO_MSG_LEN, read(acc_fd, data, RDMA_HELLO_MSG_LEN)); + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); - rdma::HelloMessage msg; - msg.msg_len = RDMA_HELLO_MSG_LEN; - msg.hello_ver = RDMA_HELLO_VERSION; - msg.impl_ver = RDMA_IMPL_VERSION; + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; msg.sq_size = 16; msg.rq_size = 16; msg.block_size = 8192; @@ -1213,23 +1235,528 @@ TEST_F(RdmaTest, server_send_data_on_tcp_after_ack) { msg.gid = rdma::GetRdmaGid(); memcpy(data, "RDMA", 4); msg.Serialize(data + 4); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); usleep(100000); - ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, s->_rdma_ep->_state); - ASSERT_EQ(RDMA_HELLO_MSG_LEN, write(acc_fd, data, RDMA_HELLO_MSG_LEN)); + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); bthread_id_join(cntl.call_id()); ASSERT_EQ(EPROTO, cntl.ErrorCode()); } + +TEST_F(RdmaTest, v2_client_hello_bytes_baseline) { + butil::fd_guard sockfd(butil::tcp_listen(g_ep)); + EXPECT_TRUE(sockfd >= 0); + + Channel channel; + 
ChannelOptions chan_options; + chan_options.socket_mode = SOCKET_MODE_RDMA; + chan_options.connect_timeout_ms = 500; + chan_options.timeout_ms = 500; + chan_options.max_retry = 0; + ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); + + Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(__FUNCTION__); + google::protobuf::Closure* done = DoNothing(); + ::test::EchoService::Stub(&channel).Echo(&cntl, &req, &res, done); + + usleep(100000); + SocketUniquePtr s; + ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); + + butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); + ASSERT_TRUE(acc_fd >= 0); + + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(acc_fd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + + // [0..4) magic + ASSERT_EQ(0, memcmp(data, "RDMA", 4)); + // [4..6) msg_len, big-endian uint16 == 40 + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, + (size_t)(((uint16_t)data[4] << 8) | (uint16_t)data[5])); + // [6..8) hello_ver, big-endian uint16 == rdma::RDMA_HELLO_V2_VERSION + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, + (uint16_t)(((uint16_t)data[6] << 8) | (uint16_t)data[7])); + // [8..10) impl_ver, big-endian uint16 == rdma::RDMA_IMPL_V2_VERSION + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, + (uint16_t)(((uint16_t)data[8] << 8) | (uint16_t)data[9])); + + rdma::v2_wire::HelloMessage msg{}; + msg.Deserialize(data + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, msg.msg_len); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, msg.hello_ver); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, msg.impl_ver); + + bthread_id_join(cntl.call_id()); +} + +TEST_F(RdmaTest, v2_server_hello_bytes_baseline) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + 
Socket* s = GetSocketFromServer(0); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Send a well-formed v2 hello so the server enters S_ACK_WAIT. + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = rdma::v2_wire::HELLO_MSG_LEN_MIN; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t data[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + memcpy(data, "RDMA", 4); + msg.Serialize(data + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(sockfd, data, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Read server's reply hello and assert its byte-level layout. + uint8_t reply[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, read(sockfd, reply, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + + ASSERT_EQ(0, memcmp(reply, "RDMA", 4)); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, + (size_t)(((uint16_t)reply[4] << 8) | (uint16_t)reply[5])); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, + (uint16_t)(((uint16_t)reply[6] << 8) | (uint16_t)reply[7])); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, + (uint16_t)(((uint16_t)reply[8] << 8) | (uint16_t)reply[9])); + + rdma::v2_wire::HelloMessage reply_msg{}; + reply_msg.Deserialize(reply + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, reply_msg.msg_len); + ASSERT_EQ(rdma::RDMA_HELLO_V2_VERSION, reply_msg.hello_ver); + ASSERT_EQ(rdma::RDMA_IMPL_V2_VERSION, reply_msg.impl_ver); + + // Drive the server into FALLBACK_TCP via ACK flags=0 so the test ends + // cleanly without requiring real RDMA hardware. 
+ uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ(sizeof(flags), write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v2_server_drains_tail_then_reads_ack) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Build a v2 hello with msg_len = 48 (40 base + 8B zero tail). + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = 48; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t buf[48]; + memcpy(buf, "RDMA", 4); + msg.Serialize(buf + 4); + memset(buf + 40, 0x00, 8); // 8B zero tail + ASSERT_EQ(48, write(sockfd, buf, 48)); + usleep(100000); + + // Send the real ACK (flags=1 = ACK_MSG_RDMA_OK). 
+ uint32_t flags = butil::HostToNet32(1); + ASSERT_EQ(sizeof(flags), write(sockfd, &flags, sizeof(flags))); + usleep(100000); + + ASSERT_EQ(rdma::RdmaEndpoint::ESTABLISHED, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v2_server_rejects_oversized_msg_len) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Build a v2 hello with msg_len = 4097 (HELLO_MSG_LEN_MAX + 1). + // We only send the 40B base; the server must reject before reading + // (and definitely before attempting to drain) any "tail". + rdma::v2_wire::HelloMessage msg{}; + msg.msg_len = 4097; + msg.hello_ver = rdma::RDMA_HELLO_V2_VERSION; + msg.impl_ver = rdma::RDMA_IMPL_V2_VERSION; + msg.sq_size = 16; + msg.rq_size = 16; + msg.block_size = 8192; + msg.qp_num = 0; + msg.gid = rdma::GetRdmaGid(); + + uint8_t buf[rdma::v2_wire::HELLO_MSG_LEN_MIN]; + memcpy(buf, "RDMA", 4); + msg.Serialize(buf + 4); + ASSERT_EQ(rdma::v2_wire::HELLO_MSG_LEN_MIN, write(sockfd, buf, rdma::v2_wire::HELLO_MSG_LEN_MIN)); + usleep(100000); + + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + usleep(100000); + + StopServer(); +} + +// RAII for FLAGS_rdma_client_handshake_version: lets us flip the +// client-side handshake version for a single test and restore it on +// scope exit so subsequent tests stay on the v2 default. 
+class HandshakeVersionFlag { +public: + explicit HandshakeVersionFlag(int v) + : _saved(rdma::FLAGS_rdma_client_handshake_version) { + rdma::FLAGS_rdma_client_handshake_version = v; + } + ~HandshakeVersionFlag() { + rdma::FLAGS_rdma_client_handshake_version = _saved; + } +private: + int _saved; +}; + +// Build a v3 wire packet from an RdmaHello: "RDM3" + pb_size_be + body. +std::string MakeV3Packet(const rdma::RdmaHello& msg) { + std::string body; + EXPECT_TRUE(msg.SerializeToString(&body)); + std::string packet; + packet.reserve(4 + 4 + body.size()); + packet.append("RDM3", 4); + uint32_t pb_size_be = + butil::HostToNet32(static_cast(body.size())); + packet.append(reinterpret_cast(&pb_size_be), 4); + packet.append(body); + return packet; +} + +// Build a fully-valid RdmaHello: all 6 required fields are set, with +// values that pass RdmaHelloV3Wire::RdmaHelloValid(). +// - block_size = 8192 (>= MIN_BLOCK_SIZE) +// - sq_size / rq_size = 16 (>= MIN_QP_SIZE) +// - gid = exactly 16B (sizeof(ibv_gid)) +// - qp_num = 0 (allowed because g_skip_rdma_init in UT) +rdma::RdmaHello MakeValidV3Hello() { + rdma::RdmaHello msg; + msg.set_block_size(8192); + msg.set_sq_size(16); + msg.set_rq_size(16); + msg.set_lid(0); + ibv_gid gid = rdma::GetRdmaGid(); + msg.set_gid(std::string(reinterpret_cast(gid.raw), + sizeof(gid.raw))); + msg.set_qp_num(0); + return msg; +} + + +TEST_F(RdmaTest, v3_client_hello_bytes_baseline) { + HandshakeVersionFlag _hsv(3); + + butil::fd_guard sockfd(butil::tcp_listen(g_ep)); + EXPECT_TRUE(sockfd >= 0); + + Channel channel; + ChannelOptions chan_options; + chan_options.socket_mode = SOCKET_MODE_RDMA; + chan_options.connect_timeout_ms = 500; + chan_options.timeout_ms = 500; + chan_options.max_retry = 0; + ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); + + Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(__FUNCTION__); + google::protobuf::Closure* done = DoNothing(); + ::test::EchoService::Stub(&channel).Echo(&cntl, 
&req, &res, done); + + butil::fd_guard acc_fd(accept(sockfd, NULL, NULL)); + ASSERT_TRUE(acc_fd >= 0); + + // [0..4) magic "RDM3" + uint8_t magic[4]; + ASSERT_EQ(4, read(acc_fd, magic, 4)); + ASSERT_EQ(0, memcmp(magic, "RDM3", 4)); + + // [4..8) pb_size, big-endian uint32, must be in (0, 4096] + uint8_t size_buf[4]; + ASSERT_EQ(4, read(acc_fd, size_buf, 4)); + uint32_t pb_size = + butil::NetToHost32(*reinterpret_cast(size_buf)); + ASSERT_GT(pb_size, 0u); + ASSERT_LE(pb_size, 4096u); + + // [8..8+pb_size) RdmaHello protobuf body. + std::string body(pb_size, '\0'); + ASSERT_EQ((ssize_t)pb_size, read(acc_fd, &body[0], pb_size)); + rdma::RdmaHello msg; + ASSERT_TRUE(msg.ParseFromString(body)); + + // All 6 required fields must be present (ParseFromString would + // have already returned false otherwise). + ASSERT_TRUE(msg.has_block_size()); + ASSERT_TRUE(msg.has_sq_size()); + ASSERT_TRUE(msg.has_rq_size()); + ASSERT_TRUE(msg.has_lid()); + ASSERT_TRUE(msg.has_gid()); + ASSERT_TRUE(msg.has_qp_num()); + // gid wire encoding must be exactly 16 bytes (sizeof(ibv_gid)). + ASSERT_EQ(sizeof(ibv_gid), msg.gid().size()); + + // Let the RPC time out and release resources. + bthread_id_join(cntl.call_id()); +} + +TEST_F(RdmaTest, v3_server_hello_bytes_baseline) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_EQ(rdma::RdmaEndpoint::UNINIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Send a valid v3 hello. 
+ std::string packet = MakeV3Packet(MakeValidV3Hello()); + ASSERT_EQ((ssize_t)packet.size(), + write(sockfd, packet.data(), packet.size())); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + + // Read server's reply hello: 4B magic + 4B pb_size + body. + uint8_t reply_magic[4]; + ASSERT_EQ(4, read(sockfd, reply_magic, 4)); + ASSERT_EQ(0, memcmp(reply_magic, "RDM3", 4)); + + uint8_t size_buf[4]; + ASSERT_EQ(4, read(sockfd, size_buf, 4)); + uint32_t pb_size = + butil::NetToHost32(*reinterpret_cast(size_buf)); + ASSERT_GT(pb_size, 0u); + ASSERT_LE(pb_size, 4096u); + + std::string body(pb_size, '\0'); + ASSERT_EQ((ssize_t)pb_size, read(sockfd, &body[0], pb_size)); + rdma::RdmaHello reply; + ASSERT_TRUE(reply.ParseFromString(body)); + ASSERT_TRUE(reply.has_block_size()); + ASSERT_TRUE(reply.has_sq_size()); + ASSERT_TRUE(reply.has_rq_size()); + ASSERT_TRUE(reply.has_gid()); + ASSERT_EQ(sizeof(ibv_gid), reply.gid().size()); + + // Drive the server into FALLBACK_TCP via ACK flags=0 so the test ends + // cleanly without requiring real RDMA hardware. + uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ((ssize_t)sizeof(flags), + write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_zero_pb_size) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + // "RDM3" + pb_size = 0 (4B big-endian zero). 
+ uint8_t buf[8] = {'R', 'D', 'M', '3', 0, 0, 0, 0}; + ASSERT_EQ(8, write(sockfd, buf, 8)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_oversized_pb_size) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + uint8_t buf[8]; + memcpy(buf, "RDM3", 4); + uint32_t pb_size_be = butil::HostToNet32(4097); + memcpy(buf + 4, &pb_size_be, 4); + ASSERT_EQ(8, write(sockfd, buf, 8)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_rejects_invalid_pb_bytes) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + // "RDM3" + pb_size = 8 + 8 bytes of 0xff (invalid protobuf body). 
+ uint8_t buf[16]; + memcpy(buf, "RDM3", 4); + uint32_t pb_size_be = butil::HostToNet32(8); + memcpy(buf + 4, &pb_size_be, 4); + memset(buf + 8, 0xff, 8); + ASSERT_EQ(16, write(sockfd, buf, 16)); + usleep(100000); + + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + sockfd.reset(-1); + StopServer(); +} + +TEST_F(RdmaTest, v3_server_invalid_sq_size_falls_back) { + StartServer(); + + sockaddr_in addr; + bzero((char*)&addr, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(PORT); + butil::fd_guard sockfd(socket(AF_INET, SOCK_STREAM, 0)); + ASSERT_TRUE(sockfd >= 0); + ASSERT_EQ(0, connect(sockfd, (sockaddr*)&addr, sizeof(sockaddr))); + usleep(100000); + Socket* s = GetSocketFromServer(0); + ASSERT_TRUE(s != NULL); + + rdma::RdmaHello msg = MakeValidV3Hello(); + msg.set_sq_size(0); // invalid: < MIN_QP_SIZE (16) + std::string packet = MakeV3Packet(msg); + ASSERT_EQ((ssize_t)packet.size(), + write(sockfd, packet.data(), packet.size())); + usleep(100000); + + // Server validated the hello as invalid -> _rdma_state = RDMA_OFF, + // but still proceeds to S_ACK_WAIT (sends its own reply hello). + ASSERT_EQ(rdma::RdmaEndpoint::S_ACK_WAIT, static_cast(s->_transport.get())->_rdma_ep->_state); + ASSERT_EQ(RdmaTransport::RDMA_OFF, static_cast(s->_transport.get())->_rdma_state); + + // Drain server's reply hello (content not asserted here; covered + // by v3_server_hello_bytes_baseline). + uint8_t reply_hdr[8]; + ASSERT_EQ(8, read(sockfd, reply_hdr, 8)); + ASSERT_EQ(0, memcmp(reply_hdr, "RDM3", 4)); + uint32_t reply_pb_size = butil::NetToHost32( + *reinterpret_cast(reply_hdr + 4)); + std::string reply_body(reply_pb_size, '\0'); + ASSERT_EQ((ssize_t)reply_pb_size, + read(sockfd, &reply_body[0], reply_pb_size)); + + // Client ACK flags=0 -> server settles into FALLBACK_TCP. 
+ uint32_t flags = butil::HostToNet32(0); + ASSERT_EQ((ssize_t)sizeof(flags), + write(sockfd, &flags, sizeof(flags))); + usleep(100000); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); + + sockfd.reset(-1); + usleep(100000); + ASSERT_EQ(NULL, GetSocketFromServer(0)); + + StopServer(); +} + TEST_F(RdmaTest, try_global_disable_rdma) { StartServer(); rdma::g_rdma_available.store(false, butil::memory_order_relaxed); Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1245,7 +1772,7 @@ TEST_F(RdmaTest, try_global_disable_rdma) { usleep(100000); SocketUniquePtr s; ASSERT_EQ(0, Socket::Address(cntl._single_server_id, &s)); - ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, s->_rdma_ep->_state); + ASSERT_EQ(rdma::RdmaEndpoint::FALLBACK_TCP, static_cast(s->_transport.get())->_rdma_ep->_state); bthread_id_join(cntl.call_id()); ASSERT_EQ(0, cntl.ErrorCode()); @@ -1256,7 +1783,7 @@ TEST_F(RdmaTest, try_global_disable_rdma) { TEST_F(RdmaTest, server_option_invalid) { Server server; ServerOptions options; - options.use_rdma = true; + options.socket_mode = SOCKET_MODE_RDMA; // rtmp and rdma are incompatible options.rtmp_service = (RtmpService*)1; @@ -1281,7 +1808,7 @@ TEST_F(RdmaTest, server_option_invalid) { TEST_F(RdmaTest, channel_option_invalid) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; // rtmp and rdma are incompatible chan_options.protocol = "rtmp"; @@ -1342,7 +1869,7 @@ TEST_F(RdmaTest, channel_option_invalid) { ASSERT_EQ(-1, channel.Init(g_ep, &chan_options)); } -TEST_F(RdmaTest, rdma_client_to_rdma_server) { +TEST_P(RdmaRpcTest, rdma_client_to_rdma_server) { if (!FLAGS_rdma_test_enable) { return; } @@ -1351,7 +1878,7 @@ TEST_F(RdmaTest, rdma_client_to_rdma_server) { 
Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1362,14 +1889,14 @@ TEST_F(RdmaTest, rdma_client_to_rdma_server) { req.set_message(__FUNCTION__); google::protobuf::Closure* done = DoNothing(); ::test::EchoService::Stub(&channel).Echo(&cntl, &req, &res, done); - usleep(100000); + // usleep(100000); bthread_id_join(cntl.call_id()); ASSERT_EQ(0, cntl.ErrorCode()); StopServer(); } -TEST_F(RdmaTest, tcp_client_to_tcp_server) { +TEST_P(RdmaRpcTest, tcp_client_to_tcp_server) { StartServer(false); Channel channel; @@ -1391,7 +1918,7 @@ TEST_F(RdmaTest, tcp_client_to_tcp_server) { StopServer(); } -TEST_F(RdmaTest, tcp_client_to_rdma_server) { +TEST_P(RdmaRpcTest, tcp_client_to_rdma_server) { StartServer(); Channel channel; @@ -1413,12 +1940,12 @@ TEST_F(RdmaTest, tcp_client_to_rdma_server) { StopServer(); } -TEST_F(RdmaTest, rdma_client_to_tcp_server) { +TEST_P(RdmaRpcTest, rdma_client_to_tcp_server) { StartServer(false); Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1440,12 +1967,12 @@ static const int RPC_NUM = 1024; void DumpRdmaEndpointInfo(Socket* client, Socket* server) { std::cout << std::endl << "client:"; - client->_rdma_ep->DebugInfo(std::cout); + static_cast(client->_transport.get())->_rdma_ep->DebugInfo(std::cout); std::cout << std::endl << "server:"; - server->_rdma_ep->DebugInfo(std::cout); + static_cast(server->_transport.get())->_rdma_ep->DebugInfo(std::cout); } -TEST_F(RdmaTest, send_rpcs_in_one_qp) { +TEST_P(RdmaRpcTest, send_rpcs_in_one_qp) { if (!FLAGS_rdma_test_enable) { return; } @@ -1454,9 +1981,9 @@ TEST_F(RdmaTest, send_rpcs_in_one_qp) { Channel channel; ChannelOptions chan_options; - 
chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; - chan_options.timeout_ms = 5000; + chan_options.timeout_ms = 50000; chan_options.max_retry = 0; ASSERT_EQ(0, channel.Init(g_ep, &chan_options)); Controller cntl[RPC_NUM]; @@ -1516,50 +2043,57 @@ TEST_F(RdmaTest, send_rpcs_in_one_qp) { Socket* m = GetSocketFromServer(0); DumpRdmaEndpointInfo(s.get(), m); } - ASSERT_TRUE(0 == cntl[i].ErrorCode() || EOVERCROWDED == cntl[i].ErrorCode()) - << "req[" << i << "] " << berror(cntl[i].ErrorCode()); + ASSERT_TRUE(0 == cntl[i].ErrorCode() || + EOVERCROWDED == cntl[i].ErrorCode()) << "req[" << i << "] " << berror(cntl[i].ErrorCode()); } + SocketUniquePtr s; + ASSERT_EQ(0, Socket::Address(cntl[0]._single_server_id, &s)); + Socket* m = GetSocketFromServer(0); + DumpRdmaEndpointInfo(s.get(), m); + StopServer(); } -TEST_F(RdmaTest, send_rpc_in_many_qp) { +TEST_P(RdmaRpcTest, send_rpc_in_many_qp) { if (!FLAGS_rdma_test_enable) { return; } + butil::ip_t ip; + ASSERT_EQ(0, butil::str2ip(g_ip.c_str(), &ip)); + Server server[100]; MyEchoService svc[100]; int num = 100; + butil::EndPoint server_eps[100]; for (int i = 0; i < num; ++i) { ServerOptions options; - options.use_rdma = true; + options.socket_mode = SOCKET_MODE_RDMA; options.idle_timeout_sec = 1; options.max_concurrency = 0; options.internal_port = -1; server[i].AddService(&svc[i], SERVER_DOESNT_OWN_SERVICE); - EXPECT_EQ(0, server[i].Start(i + 8000, &options)); + ASSERT_EQ(0, server[i].Start(0, &options)); + server_eps[i] = butil::EndPoint(ip, server[i].listen_address().port); } int port = 0; butil::IOBuf attach; attach.resize(4096); ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; - chan_options.timeout_ms = 500; + chan_options.timeout_ms = 100000; chan_options.max_retry = 0; Channel channel[RPC_NUM]; Server* svr[RPC_NUM]; Controller cntl[RPC_NUM]; test::EchoRequest 
req[RPC_NUM]; test::EchoResponse res[RPC_NUM]; - butil::ip_t ip; - butil::str2ip(g_ip.c_str(), &ip); for (int i = 0; i < RPC_NUM; ++i) { svr[i] = &server[i % num]; - butil::EndPoint ep(ip, 8000 + ((port++) % num)); - ASSERT_EQ(0, channel[i].Init(ep, &chan_options)); + ASSERT_EQ(0, channel[i].Init(server_eps[(port++) % num], &chan_options)); req[i].set_message(__FUNCTION__); cntl[i].request_attachment().append(attach); google::protobuf::Closure* done = DoNothing(); @@ -1569,16 +2103,19 @@ TEST_F(RdmaTest, send_rpc_in_many_qp) { bthread_id_join(cntl[i].call_id()); if (cntl[i].ErrorCode() == ERPCTIMEDOUT) { SocketUniquePtr s; - ASSERT_EQ(0, Socket::Address(cntl[i]._single_server_id, &s)); - std::vector sids; - svr[i]->_am->ListConnections(&sids); - for (size_t i = 0; i < sids.size(); ++i) { - SocketUniquePtr m; - ASSERT_EQ(0, Socket::AddressFailedAsWell(sids[i], &m)); - DumpRdmaEndpointInfo(s.get(), m.get()); + EXPECT_EQ(0, Socket::Address(cntl[i]._single_server_id, &s)); + if (s && svr[i] && svr[i]->_am) { + std::vector sids; + svr[i]->_am->ListConnections(&sids); + for (size_t j = 0; j < sids.size(); ++j) { + SocketUniquePtr m; + if (Socket::AddressFailedAsWell(sids[j], &m) == 0) { + DumpRdmaEndpointInfo(s.get(), m.get()); + } + } } } - ASSERT_EQ(0, cntl[i].ErrorCode()) << "req[" << i << "]"; + EXPECT_EQ(0, cntl[i].ErrorCode()) << "req[" << i << "]"; } for (int i = 0; i < num; ++i) { @@ -1587,7 +2124,7 @@ TEST_F(RdmaTest, send_rpc_in_many_qp) { } } -TEST_F(RdmaTest, send_rpcs_as_pooled_connection) { +TEST_P(RdmaRpcTest, send_rpcs_as_pooled_connection) { if (!FLAGS_rdma_test_enable) { return; } @@ -1596,7 +2133,7 @@ TEST_F(RdmaTest, send_rpcs_as_pooled_connection) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 30000; // it may very slow chan_options.timeout_ms = 30000; chan_options.max_retry = 0; @@ -1628,7 +2165,7 @@ TEST_F(RdmaTest, 
send_rpcs_as_pooled_connection) { StopServer(); } -TEST_F(RdmaTest, send_rpcs_as_short_connection) { +TEST_P(RdmaRpcTest, send_rpcs_as_short_connection) { if (!FLAGS_rdma_test_enable) { return; } @@ -1637,7 +2174,7 @@ TEST_F(RdmaTest, send_rpcs_as_short_connection) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 30000; // it may very slow chan_options.timeout_ms = 30000; chan_options.max_retry = 0; @@ -1669,7 +2206,7 @@ TEST_F(RdmaTest, send_rpcs_as_short_connection) { StopServer(); } -TEST_F(RdmaTest, server_stop_during_rpc) { +TEST_P(RdmaRpcTest, server_stop_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1678,7 +2215,7 @@ TEST_F(RdmaTest, server_stop_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; chan_options.max_retry = 0; @@ -1707,7 +2244,7 @@ TEST_F(RdmaTest, server_stop_during_rpc) { } } -TEST_F(RdmaTest, server_close_during_rpc) { +TEST_P(RdmaRpcTest, server_close_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1716,7 +2253,7 @@ TEST_F(RdmaTest, server_close_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; chan_options.max_retry = 0; @@ -1749,7 +2286,7 @@ TEST_F(RdmaTest, server_close_during_rpc) { StopServer(); } -TEST_F(RdmaTest, client_close_during_rpc) { +TEST_P(RdmaRpcTest, client_close_during_rpc) { if (!FLAGS_rdma_test_enable) { return; } @@ -1758,7 +2295,7 @@ TEST_F(RdmaTest, client_close_during_rpc) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 3000; 
chan_options.max_retry = 0; @@ -1789,7 +2326,7 @@ TEST_F(RdmaTest, client_close_during_rpc) { StopServer(); } -TEST_F(RdmaTest, verbs_error_handling) { +TEST_P(RdmaRpcTest, verbs_error_handling) { if (!FLAGS_rdma_test_enable) { return; } @@ -1798,7 +2335,7 @@ TEST_F(RdmaTest, verbs_error_handling) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1826,7 +2363,8 @@ TEST_F(RdmaTest, verbs_error_handling) { wr.sg_list = &sge; wr.num_sge = 1; ibv_send_wr* bad = NULL; - ibv_post_send(s->_rdma_ep->_resource->qp, &wr, &bad); + auto rdma_transport = static_cast(s->_transport.get()); + ibv_post_send(rdma_transport->_rdma_ep->_resource->qp, &wr, &bad); bthread_id_join(cntl.call_id()); ASSERT_EQ(ERDMA, cntl.ErrorCode()); free(buf); @@ -1834,7 +2372,7 @@ TEST_F(RdmaTest, verbs_error_handling) { StopServer(); } -TEST_F(RdmaTest, rdma_use_parallel_channel) { +TEST_P(RdmaRpcTest, rdma_use_parallel_channel) { if (!FLAGS_rdma_test_enable) { return; } @@ -1845,13 +2383,14 @@ TEST_F(RdmaTest, rdma_use_parallel_channel) { Channel subchans[NCHANS]; ParallelChannel channel; ChannelOptions opts; - opts.use_rdma = true; + opts.socket_mode = SOCKET_MODE_RDMA; for (size_t i = 0; i < NCHANS; ++i) { ASSERT_EQ(0, subchans[i].Init(_naming_url.c_str(), "rR", &opts)); ASSERT_EQ(0, channel.AddChannel( &subchans[i], DOESNT_OWN_CHANNEL, NULL, NULL)); } + ASSERT_EQ(0, channel.Init(NULL)); Controller cntl; test::EchoRequest req; @@ -1865,7 +2404,7 @@ TEST_F(RdmaTest, rdma_use_parallel_channel) { StopServer(); } -TEST_F(RdmaTest, rdma_use_selective_channel) { +TEST_P(RdmaRpcTest, rdma_use_selective_channel) { if (!FLAGS_rdma_test_enable) { return; } @@ -1875,7 +2414,7 @@ TEST_F(RdmaTest, rdma_use_selective_channel) { const size_t NCHANS = 8; SelectiveChannel channel; ChannelOptions opts; - opts.use_rdma = true; + 
opts.socket_mode = SOCKET_MODE_RDMA; ASSERT_EQ(0, channel.Init("rr", &opts)); for (size_t i = 0; i < NCHANS; ++i) { Channel* subchan = new Channel; @@ -1897,7 +2436,7 @@ TEST_F(RdmaTest, rdma_use_selective_channel) { static void MockFree(void* buf) { } -TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { +TEST_P(RdmaRpcTest, send_rpcs_with_user_defined_iobuf) { if (!FLAGS_rdma_test_enable) { return; } @@ -1906,7 +2445,7 @@ TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 500; chan_options.max_retry = 0; @@ -1961,7 +2500,7 @@ TEST_F(RdmaTest, send_rpcs_with_user_defined_iobuf) { StopServer(); } -TEST_F(RdmaTest, try_memory_pool_empty) { +TEST_P(RdmaRpcTest, try_memory_pool_empty) { if (!FLAGS_rdma_test_enable) { return; } @@ -1970,7 +2509,7 @@ TEST_F(RdmaTest, try_memory_pool_empty) { Channel channel; ChannelOptions chan_options; - chan_options.use_rdma = true; + chan_options.socket_mode = SOCKET_MODE_RDMA; chan_options.connect_timeout_ms = 500; chan_options.timeout_ms = 60000; chan_options.max_retry = 0; @@ -2000,6 +2539,19 @@ TEST_F(RdmaTest, try_memory_pool_empty) { StopServer(); } +// Run every TEST_P(RdmaRpcTest, ...) above twice: once with the +// client-side handshake forced to v2 ("RDMA" magic + fixed-layout +// HelloMessage), once with v3 ("RDM3" magic + protobuf RdmaHello). +// The server always accepts both via magic-byte dispatch, so this +// proves the upper-layer RPC paths behave identically under either +// wire format. 
+INSTANTIATE_TEST_SUITE_P( + HandshakeVersion, RdmaRpcTest, + ::testing::Values(2, 3), + [](const ::testing::TestParamInfo& info) { + return std::string("v") + std::to_string(info.param); + }); + #endif // if BRPC_WITH_RDMA int main(int argc, char* argv[]) { diff --git a/test/bvar_percentile_unittest.cpp b/test/bvar_percentile_unittest.cpp index f647e272ba..d9d01846a1 100644 --- a/test/bvar_percentile_unittest.cpp +++ b/test/bvar_percentile_unittest.cpp @@ -28,6 +28,7 @@ class PercentileTest : public testing::Test { void TearDown() {} }; +#if !WITH_BABYLON_COUNTER TEST_F(PercentileTest, add) { bvar::detail::Percentile p; for (int j = 0; j < 10; ++j) { @@ -51,6 +52,7 @@ TEST_F(PercentileTest, add) { b.describe(out); } } +#endif // !WITH_BABYLON_COUNTER TEST_F(PercentileTest, merge1) { // Merge 2 PercentileIntervals b1 and b2. b2 has double SAMPLE_SIZE