11import datadog.trace.api.datastreams.DataStreamsTags
2+ import datadog.trace.api.datastreams.DataStreamsTransactionExtractor
23import datadog.trace.instrumentation.kafka_common.ClusterIdHolder
34
45import static datadog.trace.agent.test.utils.TraceUtils.basicSpan
@@ -1047,6 +1048,126 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
10471048 producer?. close()
10481049 }
10491050
1051+ def " test producer DSM transaction tracking extracts transaction id from headers" () {
1052+ setup :
1053+ if (! isDataStreamsEnabled()) {
1054+ return
1055+ }
1056+
1057+ injectEnvConfig(" DD_DATA_STREAMS_ENABLED" , " true" )
1058+
1059+ // Configure a DSM transaction extractor for KAFKA_PRODUCE_HEADERS
1060+ def extractorsByTypeField = TEST_DATA_STREAMS_MONITORING . getClass(). getDeclaredField(" extractorsByType" )
1061+ extractorsByTypeField. setAccessible(true )
1062+ def oldExtractorsByType = extractorsByTypeField. get(TEST_DATA_STREAMS_MONITORING )
1063+
1064+ def extractor = new DataStreamsTransactionExtractor () {
1065+ String getName () {
1066+ return " kafka-produce-test"
1067+ }
1068+ DataStreamsTransactionExtractor.Type getType () {
1069+ return DataStreamsTransactionExtractor.Type . KAFKA_PRODUCE_HEADERS
1070+ }
1071+ String getValue () {
1072+ return " x-transaction-id"
1073+ }
1074+ }
1075+ def extractorsByType = new EnumMap<> (DataStreamsTransactionExtractor.Type )
1076+ extractorsByType. put(DataStreamsTransactionExtractor.Type . KAFKA_PRODUCE_HEADERS , [extractor])
1077+ extractorsByTypeField. set(TEST_DATA_STREAMS_MONITORING , extractorsByType)
1078+
1079+ def senderProps = KafkaTestUtils . senderProps(embeddedKafka. getBrokersAsString())
1080+ def producer = new KafkaProducer<> (senderProps, new StringSerializer (), new StringSerializer ())
1081+
1082+ def headers = new RecordHeaders ()
1083+ headers. add(new RecordHeader (" x-transaction-id" , " txn-123" . getBytes(StandardCharsets . UTF_8 )))
1084+
1085+ when :
1086+ def record = new ProducerRecord (SHARED_TOPIC , 0 , null , " test-dsm-transaction" , headers)
1087+ producer. send(record). get()
1088+
1089+ then :
1090+ TEST_WRITER . waitForTraces(1 )
1091+ def producedSpan = TEST_WRITER [0 ][0 ]
1092+ producedSpan. getTag(Tags . DSM_TRANSACTION_ID ) == " txn-123"
1093+ producedSpan. getTag(Tags . DSM_TRANSACTION_CHECKPOINT ) == " kafka-produce-test"
1094+
1095+ cleanup :
1096+ extractorsByTypeField?. set(TEST_DATA_STREAMS_MONITORING , oldExtractorsByType)
1097+ producer?. close()
1098+ }
1099+
1100+ def " test consumer DSM transaction tracking extracts transaction id from headers" () {
1101+ setup :
1102+ if (! isDataStreamsEnabled()) {
1103+ return
1104+ }
1105+
1106+ injectEnvConfig(" DD_DATA_STREAMS_ENABLED" , " true" )
1107+
1108+ // Configure a DSM transaction extractor for KAFKA_CONSUME_HEADERS
1109+ def extractorsByTypeField = TEST_DATA_STREAMS_MONITORING . getClass(). getDeclaredField(" extractorsByType" )
1110+ extractorsByTypeField. setAccessible(true )
1111+ def oldExtractorsByType = extractorsByTypeField. get(TEST_DATA_STREAMS_MONITORING )
1112+
1113+ def extractor = new DataStreamsTransactionExtractor () {
1114+ String getName () {
1115+ return " kafka-consume-test"
1116+ }
1117+ DataStreamsTransactionExtractor.Type getType () {
1118+ return DataStreamsTransactionExtractor.Type . KAFKA_CONSUME_HEADERS
1119+ }
1120+ String getValue () {
1121+ return " x-transaction-id"
1122+ }
1123+ }
1124+ def extractorsByType = new EnumMap<> (DataStreamsTransactionExtractor.Type )
1125+ extractorsByType. put(DataStreamsTransactionExtractor.Type . KAFKA_CONSUME_HEADERS , [extractor])
1126+ extractorsByTypeField. set(TEST_DATA_STREAMS_MONITORING , extractorsByType)
1127+
1128+ def kafkaPartition = 0
1129+ def consumerProperties = KafkaTestUtils . consumerProps(" sender" , " false" , embeddedKafka)
1130+ consumerProperties. put(ConsumerConfig . AUTO_OFFSET_RESET_CONFIG , " earliest" )
1131+ def consumer = new KafkaConsumer<String , String > (consumerProperties)
1132+
1133+ def senderProps = KafkaTestUtils . senderProps(embeddedKafka. getBrokersAsString())
1134+ def producer = new KafkaProducer<> (senderProps, new StringSerializer (), new StringSerializer ())
1135+
1136+ consumer. assign(Arrays . asList(new TopicPartition (SHARED_TOPIC , kafkaPartition)))
1137+
1138+ def headers = new RecordHeaders ()
1139+ headers. add(new RecordHeader (" x-transaction-id" , " txn-456" . getBytes(StandardCharsets . UTF_8 )))
1140+
1141+ when :
1142+ def record = new ProducerRecord (SHARED_TOPIC , kafkaPartition, null , " test-dsm-consume-transaction" , headers)
1143+ producer. send(record). get()
1144+
1145+ then :
1146+ TEST_WRITER . waitForTraces(1 )
1147+ def pollResult = KafkaTestUtils . getRecords(consumer)
1148+ def recs = pollResult. records(new TopicPartition (SHARED_TOPIC , kafkaPartition)). iterator()
1149+ recs. hasNext()
1150+ recs. next(). value() == " test-dsm-consume-transaction"
1151+
1152+ // The consume span is created by TracingIterator when iterating over records
1153+ // Find the consumer span with the DSM transaction tags
1154+ TEST_WRITER . waitForTraces(2 )
1155+ def allTraces = TEST_WRITER . toArray() as List<List<DDSpan > >
1156+ def consumerSpan = allTraces. collectMany {
1157+ it
1158+ }. find {
1159+ it. getTag(Tags . DSM_TRANSACTION_ID ) == " txn-456"
1160+ }
1161+ consumerSpan != null
1162+ consumerSpan. getTag(Tags . DSM_TRANSACTION_ID ) == " txn-456"
1163+ consumerSpan. getTag(Tags . DSM_TRANSACTION_CHECKPOINT ) == " kafka-consume-test"
1164+
1165+ cleanup :
1166+ extractorsByTypeField?. set(TEST_DATA_STREAMS_MONITORING , oldExtractorsByType)
1167+ consumer?. close()
1168+ producer?. close()
1169+ }
1170+
10501171 def containerProperties () {
10511172 try {
10521173 // Different class names for test and latestDepTest.
@@ -1057,12 +1178,12 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
10571178 }
10581179
10591180 def producerSpan (
1060- TraceAssert trace ,
1061- Map<String , ?> config ,
1062- DDSpan parentSpan = null ,
1063- boolean partitioned = true ,
1064- boolean tombstone = false ,
1065- String schema = null
1181+ TraceAssert trace ,
1182+ Map<String , ?> config ,
1183+ DDSpan parentSpan = null ,
1184+ boolean partitioned = true ,
1185+ boolean tombstone = false ,
1186+ String schema = null
10661187 ) {
10671188 trace. span {
10681189 serviceName service()
@@ -1104,8 +1225,8 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
11041225 }
11051226
11061227 def queueSpan (
1107- TraceAssert trace ,
1108- DDSpan parentSpan = null
1228+ TraceAssert trace ,
1229+ DDSpan parentSpan = null
11091230 ) {
11101231 trace. span {
11111232 serviceName splitByDestination() ? " $SHARED_TOPIC " : serviceForTimeInQueue()
@@ -1128,12 +1249,12 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
11281249 }
11291250
11301251 def consumerSpan (
1131- TraceAssert trace ,
1132- Map<String , Object > config ,
1133- DDSpan parentSpan = null ,
1134- Range offset = 0 .. 0 ,
1135- boolean tombstone = false ,
1136- boolean distributedRootSpan = ! hasQueueSpan()
1252+ TraceAssert trace ,
1253+ Map<String , Object > config ,
1254+ DDSpan parentSpan = null ,
1255+ Range offset = 0 .. 0 ,
1256+ boolean tombstone = false ,
1257+ boolean distributedRootSpan = ! hasQueueSpan()
11371258 ) {
11381259 trace. span {
11391260 serviceName service()
@@ -1169,12 +1290,12 @@ abstract class KafkaClientTestBase extends VersionedNamingTestBase {
11691290 }
11701291
11711292 def pollSpan (
1172- TraceAssert trace ,
1173- int recordCount = 1 ,
1174- DDSpan parentSpan = null ,
1175- Range offset = 0 .. 0 ,
1176- boolean tombstone = false ,
1177- boolean distributedRootSpan = ! hasQueueSpan()
1293+ TraceAssert trace ,
1294+ int recordCount = 1 ,
1295+ DDSpan parentSpan = null ,
1296+ Range offset = 0 .. 0 ,
1297+ boolean tombstone = false ,
1298+ boolean distributedRootSpan = ! hasQueueSpan()
11781299 ) {
11791300 trace. span {
11801301 serviceName Config . get(). getServiceName()
0 commit comments