Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 118 additions & 20 deletions src/brpc/rdma/rdma_endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,16 @@ static const size_t RESERVED_WR_NUM = 3;
// block size (4B)
// sq size (2B)
// rq size (2B)
// lid size (2B)
// GID (16B)
// QP number (4B)
// mtu type (2B)
static const char* MAGIC_STR = "RDMA";
static const size_t MAGIC_STR_LEN = 4;
static const size_t HELLO_MSG_LEN_MIN = 40;
// static const size_t HELLO_MSG_LEN_MAX = 4096;
static const size_t ACK_MSG_LEN = 4;
static uint16_t g_rdma_hello_msg_len = 40; // In Byte
static uint16_t g_rdma_hello_msg_len = 42; // In Byte
static uint16_t g_rdma_hello_version = 2;
Comment on lines +93 to 94
static uint16_t g_rdma_impl_version = 1;
static uint32_t g_rdma_recv_block_size = 0;
Expand All @@ -105,10 +107,16 @@ static const uint32_t ACK_MSG_RDMA_OK = 0x1;
static butil::Mutex* g_rdma_resource_mutex = NULL;
static RdmaResource* g_rdma_resource_list = NULL;

// The HelloMessage should have all base fields, and the new versions of HelloMessage
// maybe have some extern fields.

struct HelloMessage {
void Serialize(void* data) const;
void Deserialize(void* data);
void BaseSerialize(void* data) const;
void ExtSerialize(void* data) const;
void BaseDeserialize(void* data);
uint16_t ExtDeserialize(void* data, uint16_t ext_len);

// base fields
uint16_t msg_len;
uint16_t hello_ver;
uint16_t impl_ver;
Expand All @@ -118,9 +126,12 @@ struct HelloMessage {
uint16_t lid;
ibv_gid gid;
uint32_t qp_num;

// extern fields
uint16_t mtu_type;
};

void HelloMessage::Serialize(void* data) const {
void HelloMessage::BaseSerialize(void* data) const {
uint16_t* current_pos = (uint16_t*)data;
*(current_pos++) = butil::HostToNet16(msg_len);
*(current_pos++) = butil::HostToNet16(hello_ver);
Expand All @@ -132,11 +143,17 @@ void HelloMessage::Serialize(void* data) const {
*(current_pos++) = butil::HostToNet16(rq_size);
*(current_pos++) = butil::HostToNet16(lid);
memcpy(current_pos, gid.raw, 16);
uint32_t* qp_num_pos = (uint32_t*)((char*)current_pos + 16);
current_pos += 8;
uint32_t* qp_num_pos = (uint32_t*)(current_pos);
*qp_num_pos = butil::HostToNet32(qp_num);
}

void HelloMessage::Deserialize(void* data) {
void HelloMessage::ExtSerialize(void* data) const {
uint16_t* current_pos = (uint16_t*)data;
*(current_pos) = butil::HostToNet16(mtu_type);
}

void HelloMessage::BaseDeserialize(void* data) {
uint16_t* current_pos = (uint16_t*)data;
msg_len = butil::NetToHost16(*current_pos++);
hello_ver = butil::NetToHost16(*current_pos++);
Expand All @@ -147,7 +164,26 @@ void HelloMessage::Deserialize(void* data) {
rq_size = butil::NetToHost16(*current_pos++);
lid = butil::NetToHost16(*current_pos++);
memcpy(gid.raw, current_pos, 16);
qp_num = butil::NetToHost32(*(uint32_t*)((char*)current_pos + 16));
current_pos += 8;
qp_num = butil::NetToHost32(*(uint32_t*)(current_pos));
}

uint16_t HelloMessage::ExtDeserialize(void* data, uint16_t ext_len) {
if (ext_len == 0) {
return 0;
}

uint16_t remain_ext_len = ext_len;

// try to deserialize mtu_type
if (remain_ext_len < 2) {
LOG(FATAL) << "illegal HelloMessage, remain ext len is " << remain_ext_len << ", should not be less than 2!!!";
}
uint16_t* current_pos = (uint16_t*)data;
mtu_type = butil::NetToHost16(*current_pos++);
remain_ext_len -= 2;

return remain_ext_len;
}

RdmaResource::~RdmaResource() {
Expand Down Expand Up @@ -435,6 +471,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
<< "Start handshake on " << s->_local_side;

uint8_t data[g_rdma_hello_msg_len];
uint16_t local_mtu_type = GetLocalMtuType();

// First initialize CQ and QP resources
ep->_state = C_ALLOC_QPCQ;
Expand Down Expand Up @@ -463,8 +500,10 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
// Only happens in UT
local_msg.qp_num = 0;
}
local_msg.mtu_type = local_mtu_type;
memcpy(data, MAGIC_STR, 4);
local_msg.Serialize((char*)data + 4);
local_msg.BaseSerialize((char*)data + 4);
local_msg.ExtSerialize((char*)data + HELLO_MSG_LEN_MIN);
if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to send hello message to server:" << s->description();
Expand Down Expand Up @@ -502,7 +541,7 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
return NULL;
}
HelloMessage remote_msg;
remote_msg.Deserialize(data);
remote_msg.BaseDeserialize(data);
if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) {
LOG(WARNING) << "Fail to parse Hello Message length from server:"
<< s->description();
Expand All @@ -512,9 +551,27 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
return NULL;
}

// In older versions of brpc, IBV_MTU_1024 is the default mtu type,
// So we set remote_mtu IBV_MTU_1024 at default to be ompatible with older versions.
uint16_t remote_mtu_type = IBV_MTU_1024;
if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) {
// TODO: Read Hello Message customized data
// Just for future use, should not happen now
// Read Hello Message customized data
uint16_t remote_msg_ext_len = remote_msg.msg_len - HELLO_MSG_LEN_MIN;
uint8_t ext_data[remote_msg_ext_len];
if (ep->ReadFromFd(ext_data, remote_msg_ext_len) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to get Hello Message ext fields from server:" << s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
return NULL;
}
remote_msg.ExtDeserialize(ext_data, remote_msg_ext_len);
if (remote_msg_ext_len >= 2) {
// mtu_type field is valid
remote_mtu_type = remote_msg.mtu_type;
}
// TODO: other extern fields
}

if (!HelloNegotiationValid(remote_msg)) {
Expand All @@ -534,7 +591,9 @@ void* RdmaEndpoint::ProcessHandshakeAtClient(void* arg) {
ep->_local_window_capacity, butil::memory_order_relaxed);

ep->_state = C_BRINGUP_QP;
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) {
// use the minimum of local mtu type and remote mtu type
uint16_t min_mtu_type = std::min(local_mtu_type, remote_mtu_type);
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num, min_mtu_type) < 0) {
LOG(WARNING) << "Fail to bringup QP, fallback to tcp:" << s->description();
Comment on lines +594 to 597
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
Expand Down Expand Up @@ -582,6 +641,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
<< "Start handshake on " << s->description();

uint8_t data[g_rdma_hello_msg_len];
uint16_t local_mtu_type = GetLocalMtuType();

ep->_state = S_HELLO_WAIT;
if (ep->ReadFromFd(data, MAGIC_STR_LEN) < 0) {
Expand All @@ -605,7 +665,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
return NULL;
}

if (ep->ReadFromFd(data, g_rdma_hello_msg_len - MAGIC_STR_LEN) < 0) {
if (ep->ReadFromFd(data, HELLO_MSG_LEN_MIN - MAGIC_STR_LEN) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to read Hello Message from client:" << s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
Expand All @@ -615,7 +675,7 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
}

HelloMessage remote_msg;
remote_msg.Deserialize(data);
remote_msg.BaseDeserialize(data);
if (remote_msg.msg_len < HELLO_MSG_LEN_MIN) {
LOG(WARNING) << "Fail to parse Hello Message length from client:"
<< s->description();
Expand All @@ -624,9 +684,28 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
ep->_state = FAILED;
return NULL;
}

// In older versions of brpc, IBV_MTU_1024 is the default mtu type,
// So we set remote_mtu IBV_MTU_1024 at default to be ompatible with older versions.
uint16_t remote_mtu_type = IBV_MTU_1024;
if (remote_msg.msg_len > HELLO_MSG_LEN_MIN) {
// TODO: Read Hello Message customized header
// Just for future use, should not happen now
// Read Hello Message customized data
uint16_t remote_msg_ext_len = remote_msg.msg_len - HELLO_MSG_LEN_MIN;
uint8_t ext_data[remote_msg_ext_len];
if (ep->ReadFromFd(ext_data, remote_msg_ext_len) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to get Hello Message ext fields from client:" << s->description();
s->SetFailed(saved_errno, "Fail to complete rdma handshake from %s: %s",
s->description().c_str(), berror(saved_errno));
ep->_state = FAILED;
return NULL;
}
remote_msg.ExtDeserialize(ext_data, remote_msg_ext_len);
if (remote_msg_ext_len >= 2) {
// mtu_type field is valid
remote_mtu_type = remote_msg.mtu_type;
}
// TODO: other extern fields
}

if (!HelloNegotiationValid(remote_msg)) {
Expand All @@ -652,7 +731,9 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
} else {
ep->_state = S_BRINGUP_QP;
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num) < 0) {
// use the minimum of local mtu type and remote mtu type
uint16_t min_mtu_type = std::min(local_mtu_type, remote_mtu_type);
if (ep->BringUpQp(remote_msg.lid, remote_msg.gid, remote_msg.qp_num, min_mtu_type) < 0) {
LOG(WARNING) << "Fail to bringup QP, fallback to tcp:"
<< s->description();
rdma_transport->_rdma_state = RdmaTransport::RDMA_OFF;
Expand Down Expand Up @@ -681,9 +762,11 @@ void* RdmaEndpoint::ProcessHandshakeAtServer(void* arg) {
// Only happens in UT
local_msg.qp_num = 0;
}
local_msg.mtu_type = local_mtu_type;
}
memcpy(data, MAGIC_STR, 4);
local_msg.Serialize((char*)data + 4);
local_msg.BaseSerialize((char*)data + 4);
local_msg.ExtSerialize((char*)data + HELLO_MSG_LEN_MIN);
if (ep->WriteToFd(data, g_rdma_hello_msg_len) < 0) {
const int saved_errno = errno;
PLOG(WARNING) << "Fail to send Hello Message to client:" << s->description();
Expand Down Expand Up @@ -1232,12 +1315,27 @@ int RdmaEndpoint::AllocateResources() {
return 0;
}

int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num) {
int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num, uint16_t mtu_type) {
if (BAIDU_UNLIKELY(g_skip_rdma_init)) {
// For UT
return 0;
}

if (mtu_type == IBV_MTU_256) {
LOG(INFO) << "negotiated mtu is 256";
} else if (mtu_type == IBV_MTU_512) {
LOG(INFO) << "negotiated mtu is 512";
} else if (mtu_type == IBV_MTU_1024) {
LOG(INFO) << "negotiated mtu is 1024";
} else if (mtu_type == IBV_MTU_2048) {
LOG(INFO) << "negotiated mtu is 2048";
} else if (mtu_type == IBV_MTU_4096) {
LOG(INFO) << "negotiated mtu is 4096";
Comment on lines +1325 to +1333
} else {
LOG(ERROR) << "unknown mtu " << mtu_type;
return -1;
}

ibv_qp_attr attr;

attr.qp_state = IBV_QPS_INIT;
Expand Down Expand Up @@ -1275,7 +1373,7 @@ int RdmaEndpoint::BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num) {
}

attr.qp_state = IBV_QPS_RTR;
attr.path_mtu = IBV_MTU_1024; // TODO: support more mtu in future
attr.path_mtu = ibv_mtu(mtu_type);
attr.ah_attr.grh.dgid = gid;
attr.ah_attr.grh.flow_label = 0;
attr.ah_attr.grh.sgid_index = GetRdmaGidIndex();
Expand Down
3 changes: 2 additions & 1 deletion src/brpc/rdma/rdma_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,11 @@ friend class Socket;
// lid: remote LID
// gid: remote GID
// qp_num: remote QP number
// mtu_type: the minimum of local mtu_type and remote mtu_type
// Return:
// 0: success
// -1: failed, errno set
int BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num);
int BringUpQp(uint16_t lid, ibv_gid gid, uint32_t qp_num, uint16_t mtu_type);

// Get event from comp channel and ack the events
int GetAndAckEvents(SocketUniquePtr& s);
Expand Down
41 changes: 41 additions & 0 deletions src/brpc/rdma/rdma_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ static int g_comp_vector_index = 0;

butil::atomic<bool> g_rdma_available(false);

static uint16_t local_mtu_type = IBV_MTU_4096;

DEFINE_int32(rdma_max_sge, 0, "Max SGE num in a WR");
DEFINE_string(rdma_device, "", "The name of the HCA device used "
"(Empty means using the first active device)");
Expand Down Expand Up @@ -455,6 +457,36 @@ static ibv_context* OpenDevice(int num_total, int* num_available_devices) {
return ret_context;
}

static uint16_t detect_mtu(struct ibv_context* ctx, int port_num) {
struct ibv_port_attr port_attr;

if (ibv_query_port(ctx, port_num, &port_attr)) {
LOG(ERROR) << "ibv_query_port failed";
Comment on lines +462 to +464
return 0;
}

LOG(INFO) << "local active mtu type:" << port_attr.active_mtu
<< ", max mtu type:" << port_attr.max_mtu;

uint16_t mtu_type = port_attr.active_mtu;
if (mtu_type == IBV_MTU_256) {
LOG(INFO) << "local mtu is 256";
} else if (mtu_type == IBV_MTU_512) {
LOG(INFO) << "local mtu is 512";
} else if (mtu_type == IBV_MTU_1024) {
LOG(INFO) << "local mtu is 1024";
} else if (mtu_type == IBV_MTU_2048) {
LOG(INFO) << "local mtu is 2048";
} else if (mtu_type == IBV_MTU_4096) {
LOG(INFO) << "local mtu is 4096";
} else {
LOG(ERROR) << "unknown mtu type " << mtu_type;
return 0;
}

return mtu_type;
}

static void GlobalRdmaInitializeOrDieImpl() {
if (BAIDU_UNLIKELY(g_skip_rdma_init)) {
// Just for UT
Expand Down Expand Up @@ -549,6 +581,11 @@ static void GlobalRdmaInitializeOrDieImpl() {
g_max_sge = attr.max_sge;
}

local_mtu_type = detect_mtu(g_context, g_port_num);
if (!local_mtu_type) {
PLOG(ERROR) << "Fail to get local mtu type";
ExitWithError();
}
// Initialize RDMA memory pool (block_pool)
if (!InitBlockPool(RdmaRegisterMemory)) {
PLOG(ERROR) << "Fail to initialize RDMA memory pool";
Expand Down Expand Up @@ -701,6 +738,10 @@ bool SupportedByRdma(std::string protocol) {
return false;
}

uint16_t GetLocalMtuType() {
return local_mtu_type;
}

bool InitPollingModeWithTag(bthread_tag_t tag,
std::function<void(void)> callback,
std::function<void(void)> init_fn,
Expand Down
1 change: 1 addition & 0 deletions src/brpc/rdma/rdma_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void GlobalDisableRdma();
// If the given protocol supported by RDMA
bool SupportedByRdma(std::string protocol);

uint16_t GetLocalMtuType();
} // namespace rdma
} // namespace brpc
#else
Expand Down
Loading