Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package datadog.trace.instrumentation.kafka_clients;

import static datadog.trace.api.Functions.UTF8_BYTES_TO_STRING;
import static datadog.trace.api.telemetry.LogCollector.EXCLUDE_TELEMETRY;
import static datadog.trace.instrumentation.kafka_clients.KafkaDecorator.KAFKA_PRODUCED_KEY;
import static java.nio.charset.StandardCharsets.UTF_8;

import datadog.trace.api.Config;
import datadog.trace.api.Functions;
import datadog.trace.bootstrap.instrumentation.api.AgentPropagation;
import datadog.trace.bootstrap.instrumentation.api.AgentPropagation.ContextVisitor;
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.function.Function;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;
import org.slf4j.Logger;
Expand All @@ -19,10 +23,17 @@ public class TextMapExtractAdapter implements ContextVisitor<Headers> {
public static final TextMapExtractAdapter GETTER =
new TextMapExtractAdapter(Config.get().isKafkaClientBase64DecodingEnabled());

private final Base64.Decoder base64;
private final Function<byte[], String> headerValueTransformer;
private final Base64.Decoder decoder;

public TextMapExtractAdapter(boolean base64DecodeHeaders) {
this.base64 = base64DecodeHeaders ? Base64.getDecoder() : null;
public TextMapExtractAdapter(boolean decodeBase64Headers) {
if (decodeBase64Headers) {
this.headerValueTransformer = Functions.base64Decode(UTF_8);
this.decoder = Base64.getDecoder();
} else {
this.headerValueTransformer = UTF8_BYTES_TO_STRING;
this.decoder = null;
}
}

@Override
Expand All @@ -33,10 +44,12 @@ public void forEachKey(Headers carrier, AgentPropagation.KeyClassifier classifie
if (null == value) {
continue;
}
if (base64 != null) {
value = base64.decode(value);
String decoded = headerValueTransformer.apply(value);
if (decoded == null) {
log.debug(EXCLUDE_TELEMETRY, "Failed to Base64-decode Kafka header '{}', skipping", key);
continue;
}
if (!classifier.accept(key, new String(value, UTF_8))) {
if (!classifier.accept(key, decoded)) {
return;
}
}
Expand All @@ -47,11 +60,11 @@ public long extractTimeInQueueStart(Headers carrier) {
if (null != header) {
try {
ByteBuffer buf = ByteBuffer.allocate(8);
buf.put(base64 != null ? base64.decode(header.value()) : header.value());
buf.put(decoder != null ? decoder.decode(header.value()) : header.value());
buf.flip();
return buf.getLong();
} catch (Exception e) {
log.debug("Unable to get kafka produced time", e);
log.debug(EXCLUDE_TELEMETRY, "Unable to get kafka produced time", e);
}
}
return 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datadog.trace.api.datastreams.DataStreamsTags
import datadog.trace.api.datastreams.DataStreamsTransactionExtractor
import datadog.trace.api.config.TraceInstrumentationConfig
import datadog.trace.api.config.TracerConfig
import datadog.trace.instrumentation.kafka_common.ClusterIdHolder

import static datadog.trace.agent.test.utils.TraceUtils.basicSpan
Expand Down Expand Up @@ -1540,9 +1541,32 @@ class KafkaClientDataStreamsDisabledForkedTest extends KafkaClientTestBase {
}

// Variant of KafkaClientV0ForkedTest that sets LEGACY_CONTEXT_MANAGER_ENABLED to "false"
// before the agent is configured; it adds no feature methods of its own.
class KafkaClientContextSwapForkedTest extends KafkaClientV0ForkedTest {
@Override
void configurePreAgent() {
// Keep the parent's pre-agent configuration, then layer the override on top of it.
super.configurePreAgent()
injectSysConfig(TraceInstrumentationConfig.LEGACY_CONTEXT_MANAGER_ENABLED, "false")
}
}

// Forked spec checking that a trace is still written when a record is sent with headers that are
// not valid Base64 while Kafka client Base64 header decoding is enabled.
class KafkaClientBadBase64HeaderForkedTest extends KafkaClientV0ForkedTest {
def "producer span is created when message carries non-Base64 headers and base64 decoding is enabled"() {
setup:
// NOTE(review): the sibling KafkaClientContextSwapForkedTest injects its config in
// configurePreAgent(); here it is injected inside the feature method. Confirm the settings are
// picked up this late -- TextMapExtractAdapter.GETTER reads Config once, when it is created.
injectSysConfig(TraceInstrumentationConfig.KAFKA_CLIENT_BASE64_DECODING_ENABLED, "true")
// Maps x-custom-header to the my.custom.tag tag -- presumably so the header value is actually
// read during the test; TODO confirm.
injectSysConfig(TracerConfig.HEADER_TAGS, "x-custom-header:my.custom.tag")
def senderProps = KafkaTestUtils.senderProps(embeddedKafka.getBrokersAsString())
def producer = new KafkaProducer<String, String>(senderProps, new StringSerializer(), new StringSerializer())

when:
// Both header values contain characters ('!', '@', '#') outside the standard Base64 alphabet.
def headers = new RecordHeaders([
new RecordHeader("x-custom-header", "not-valid-base64!@#".getBytes(StandardCharsets.UTF_8)),
new RecordHeader("x-another-header", "also-not-base64!!".getBytes(StandardCharsets.UTF_8))
])
// get() blocks until the send has completed.
producer.send(new ProducerRecord<>(SHARED_TOPIC, 0, null, "hello", headers)).get()

then:
// Only asserts that at least one trace was written; span contents are not inspected.
TEST_WRITER.waitForTraces(1)
!TEST_WRITER.isEmpty()

cleanup:
producer.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ import com.google.common.io.BaseEncoding
import datadog.trace.agent.test.InstrumentationSpecification
import datadog.trace.bootstrap.instrumentation.api.AgentPropagation
import datadog.trace.instrumentation.kafka_clients.TextMapExtractAdapter
import java.nio.charset.StandardCharsets
import org.apache.kafka.common.header.Headers
import org.apache.kafka.common.header.internals.RecordHeader
import org.apache.kafka.common.header.internals.RecordHeaders

import java.nio.charset.StandardCharsets

class TextMapExtractAdapterTest extends InstrumentationSpecification {

def "check can decode base64 mangled headers"() {
Expand All @@ -32,4 +31,27 @@ class TextMapExtractAdapterTest extends InstrumentationSpecification {
where:
base64Decode << [true, false]
}

// A header whose value is not valid Base64 must be dropped without aborting the iteration:
// headers that follow it are still handed to the classifier.
def "invalid base64 header is skipped and subsequent valid headers are still processed"() {
Comment thread
amarziali marked this conversation as resolved.
given:
def validBase64 = BaseEncoding.base64().encode("bar".getBytes(StandardCharsets.UTF_8))
// The undecodable header comes first so the test proves iteration continues past it.
Headers headers = new RecordHeaders([
new RecordHeader("bad-key", "not-valid-base64!@#".getBytes(StandardCharsets.UTF_8)),
new RecordHeader("good-key", validBase64.getBytes(StandardCharsets.UTF_8))
])
// true = Base64-decode header values.
TextMapExtractAdapter adapter = new TextMapExtractAdapter(true)
when:
Map<String, String> extracted = [:]
// Recording classifier: stores every key/value it is offered and never stops the iteration.
adapter.forEachKey(headers, new AgentPropagation.KeyClassifier() {
@Override
boolean accept(String key, String value) {
extracted[key] = value
return true
}
})
then:
noExceptionThrown()
!extracted.containsKey("bad-key")
extracted["good-key"] == "bar"
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package datadog.trace.instrumentation.kafka_clients38;

import static datadog.trace.api.Functions.UTF8_BYTES_TO_STRING;
import static datadog.trace.api.telemetry.LogCollector.EXCLUDE_TELEMETRY;
import static java.nio.charset.StandardCharsets.UTF_8;

import datadog.trace.api.Config;
import datadog.trace.api.Functions;
import datadog.trace.bootstrap.instrumentation.api.AgentPropagation;
import datadog.trace.bootstrap.instrumentation.api.AgentPropagation.ContextVisitor;
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.function.Function;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;
import org.slf4j.Logger;
Expand All @@ -18,10 +22,17 @@ public class TextMapExtractAdapter implements ContextVisitor<Headers> {
public static final TextMapExtractAdapter GETTER =
new TextMapExtractAdapter(Config.get().isKafkaClientBase64DecodingEnabled());

private final Base64.Decoder base64;
private final Function<byte[], String> headerValueTransformer;
private final Base64.Decoder decoder;

public TextMapExtractAdapter(boolean base64DecodeHeaders) {
this.base64 = base64DecodeHeaders ? Base64.getDecoder() : null;
public TextMapExtractAdapter(boolean decodeBase64Headers) {
if (decodeBase64Headers) {
this.headerValueTransformer = Functions.base64Decode(UTF_8);
this.decoder = Base64.getDecoder();
} else {
this.headerValueTransformer = UTF8_BYTES_TO_STRING;
this.decoder = null;
}
}

@Override
Expand All @@ -32,10 +43,12 @@ public void forEachKey(Headers carrier, AgentPropagation.KeyClassifier classifie
if (null == value) {
continue;
}
if (base64 != null) {
value = base64.decode(value);
String decoded = headerValueTransformer.apply(value);
if (decoded == null) {
log.debug(EXCLUDE_TELEMETRY, "Failed to Base64-decode Kafka header '{}', skipping", key);
continue;
}
if (!classifier.accept(key, new String(value, UTF_8))) {
if (!classifier.accept(key, decoded)) {
return;
}
}
Expand All @@ -46,11 +59,11 @@ public long extractTimeInQueueStart(Headers carrier) {
if (null != header) {
try {
ByteBuffer buf = ByteBuffer.allocate(8);
buf.put(base64 != null ? base64.decode(header.value()) : header.value());
buf.put(decoder != null ? decoder.decode(header.value()) : header.value());
buf.flip();
return buf.getLong();
} catch (Exception e) {
log.debug("Unable to get kafka produced time", e);
log.debug(EXCLUDE_TELEMETRY, "Unable to get kafka produced time", e);
}
}
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import datadog.trace.agent.test.asserts.TraceAssert
import datadog.trace.agent.test.naming.VersionedNamingTestBase
import datadog.trace.api.Config
import datadog.trace.api.config.TraceInstrumentationConfig
import datadog.trace.api.config.TracerConfig
import datadog.trace.api.DDTags
import datadog.trace.api.datastreams.DataStreamsTags
import datadog.trace.bootstrap.instrumentation.api.InstrumentationTags
Expand Down Expand Up @@ -1210,9 +1211,32 @@ class KafkaClientDataStreamsDisabledForkedTest extends KafkaClientTestBase {
}

// Variant of KafkaClientV0ForkedTest that sets LEGACY_CONTEXT_MANAGER_ENABLED to "false"
// before the agent is configured; it adds no feature methods of its own.
class KafkaClientContextSwapForkedTest extends KafkaClientV0ForkedTest {
@Override
void configurePreAgent() {
// Keep the parent's pre-agent configuration, then layer the override on top of it.
super.configurePreAgent()
injectSysConfig(TraceInstrumentationConfig.LEGACY_CONTEXT_MANAGER_ENABLED, "false")
}
}

// Forked spec checking that a trace is still written when a record is sent with headers that are
// not valid Base64 while Kafka client Base64 header decoding is enabled.
class KafkaClientBadBase64HeaderForkedTest extends KafkaClientV0ForkedTest {
def "producer span is created when message carries non-Base64 headers and base64 decoding is enabled"() {
setup:
// NOTE(review): the sibling KafkaClientContextSwapForkedTest injects its config in
// configurePreAgent(); here it is injected inside the feature method. Confirm the settings are
// picked up this late -- TextMapExtractAdapter.GETTER reads Config once, when it is created.
injectSysConfig(TraceInstrumentationConfig.KAFKA_CLIENT_BASE64_DECODING_ENABLED, "true")
// Maps x-custom-header to the my.custom.tag tag -- presumably so the header value is actually
// read during the test; TODO confirm.
injectSysConfig(TracerConfig.HEADER_TAGS, "x-custom-header:my.custom.tag")
def producerProps = KafkaTestUtils.producerProps(embeddedKafka.getBrokersAsString())
def producer = new KafkaProducer<String, String>(producerProps, new StringSerializer(), new StringSerializer())

when:
// Both header values contain characters ('!', '@', '#') outside the standard Base64 alphabet.
def headers = new RecordHeaders([
new RecordHeader("x-custom-header", "not-valid-base64!@#".getBytes(StandardCharsets.UTF_8)),
new RecordHeader("x-another-header", "also-not-base64!!".getBytes(StandardCharsets.UTF_8))
])
// get() blocks until the send has completed.
producer.send(new ProducerRecord<>(SHARED_TOPIC, 0, null, "hello", headers)).get()

then:
// Only asserts that at least one trace was written; span contents are not inspected.
TEST_WRITER.waitForTraces(1)
!TEST_WRITER.isEmpty()

cleanup:
producer.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ import com.google.common.io.BaseEncoding
import datadog.trace.agent.test.InstrumentationSpecification
import datadog.trace.bootstrap.instrumentation.api.AgentPropagation
import datadog.trace.instrumentation.kafka_clients38.TextMapExtractAdapter
import java.nio.charset.StandardCharsets
import org.apache.kafka.common.header.Headers
import org.apache.kafka.common.header.internals.RecordHeader
import org.apache.kafka.common.header.internals.RecordHeaders

import java.nio.charset.StandardCharsets

class TextMapExtractAdapterTest extends InstrumentationSpecification {

def "check can decode base64 mangled headers"() {
Expand All @@ -32,4 +31,27 @@ class TextMapExtractAdapterTest extends InstrumentationSpecification {
where:
base64Decode << [true, false]
}

// A header whose value is not valid Base64 must be dropped without aborting the iteration:
// headers that follow it are still handed to the classifier.
def "invalid base64 header is skipped and subsequent valid headers are still processed"() {
given:
def validBase64 = BaseEncoding.base64().encode("bar".getBytes(StandardCharsets.UTF_8))
// The undecodable header comes first so the test proves iteration continues past it.
Headers headers = new RecordHeaders([
new RecordHeader("bad-key", "not-valid-base64!@#".getBytes(StandardCharsets.UTF_8)),
new RecordHeader("good-key", validBase64.getBytes(StandardCharsets.UTF_8))
])
// true = Base64-decode header values.
TextMapExtractAdapter adapter = new TextMapExtractAdapter(true)
when:
Map<String, String> extracted = [:]
// Recording classifier: stores every key/value it is offered and never stops the iteration.
adapter.forEachKey(headers, new AgentPropagation.KeyClassifier() {
@Override
boolean accept(String key, String value) {
extracted[key] = value
return true
}
})
then:
noExceptionThrown()
!extracted.containsKey("bad-key")
extracted["good-key"] == "bar"
}
}
29 changes: 29 additions & 0 deletions internal-api/src/main/java/datadog/trace/api/Functions.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package datadog.trace.api;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.function.Function.identity;

import datadog.trace.bootstrap.instrumentation.api.UTF8BytesString;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.charset.Charset;
import java.util.Base64;
import java.util.Locale;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand Down Expand Up @@ -174,4 +177,30 @@ public T apply(Object input) {
}
}
}

/** Converts a byte array to a {@link String} by decoding it as UTF-8. */
public static final Function<byte[], String> UTF8_BYTES_TO_STRING =
bytes -> new String(bytes, UTF_8);

/**
 * Creates a function that Base64-decodes a byte array and converts the decoded bytes to a
 * {@link String} using the given charset.
 *
 * <p>The returned function yields {@code null} instead of throwing when the input cannot be
 * decoded.
 *
 * @param charset charset used to build the string from the decoded bytes
 * @return a new decoding function
 */
public static Function<byte[], String> base64Decode(Charset charset) {
return new Base64Decode(charset);
}

/**
 * Function that Base64-decodes its input and turns the decoded bytes into a {@link String} using
 * the charset supplied at construction.
 *
 * <p>Any exception raised while decoding or building the string is swallowed and reported to the
 * caller as a {@code null} result.
 */
private static final class Base64Decode implements Function<byte[], String> {
  private final Charset charset;
  private final Base64.Decoder decoder = Base64.getDecoder();

  private Base64Decode(Charset charset) {
    this.charset = charset;
  }

  /**
   * @param bytes Base64-encoded input
   * @return the decoded text, or {@code null} when decoding fails
   */
  @Override
  public String apply(byte[] bytes) {
    String text;
    try {
      final byte[] raw = decoder.decode(bytes);
      text = new String(raw, charset);
    } catch (final Exception ignored) {
      // Deliberately dropped: a null result is how "could not decode" is signalled.
      text = null;
    }
    return text;
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package datadog.trace.api;

import static datadog.trace.api.Functions.UTF8_BYTES_TO_STRING;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;

import java.util.Base64;
import java.util.function.Function;
import org.junit.jupiter.api.Test;

/** Unit tests for the byte-array-to-string helpers exposed by {@link Functions}. */
class FunctionsBase64Test {

  /** Builds a fresh UTF-8 Base64 decoding function for each test. */
  private static Function<byte[], String> utf8Base64Decoder() {
    return Functions.base64Decode(UTF_8);
  }

  @Test
  void utf8BytesToStringConvertsBytes() {
    assertEquals("hello", UTF8_BYTES_TO_STRING.apply("hello".getBytes(UTF_8)));
  }

  @Test
  void base64DecodeDecodesValidInput() {
    final String expected = "x-datadog-trace-id";
    final byte[] base64Bytes = Base64.getEncoder().encode(expected.getBytes(UTF_8));

    assertEquals(expected, utf8Base64Decoder().apply(base64Bytes));
  }

  @Test
  void base64DecodeReturnsNullForInvalidBase64() {
    final byte[] garbage = "not-valid-base64!@#".getBytes(UTF_8);

    assertNull(utf8Base64Decoder().apply(garbage));
  }

  @Test
  void base64DecodeReturnsNullForUrlSafeChars() {
    // '-' and '_' belong to the URL-safe alphabet only; the standard decoder rejects them.
    final byte[] urlSafe = "abc-def_ghi".getBytes(UTF_8);

    assertNull(utf8Base64Decoder().apply(urlSafe));
  }

  @Test
  void base64DecodeInstanceIsNotNull() {
    assertNotNull(utf8Base64Decoder());
  }
}