Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,28 +79,47 @@ private void generatePciXml(StringBuilder gpuBuilder) {
gpuBuilder.append(" <driver name='vfio'/>\n");
gpuBuilder.append(" <source>\n");

// Parse the bus address (e.g., 00:02.0) into domain, bus, slot, function
String domain = "0x0000";
String bus = "0x00";
String slot = "0x00";
String function = "0x0";
// Parse the bus address into domain, bus, slot, function. Two input formats are accepted:
// - "dddd:bb:ss.f" full PCI address with domain (e.g. 0000:00:02.0)
// - "bb:ss.f" legacy short BDF; domain defaults to 0000
// Each segment is parsed as a hex integer and formatted with fixed widths
// (domain: 4 hex digits, bus/slot: 2 hex digits, function: 1 hex digit) to
// produce canonical libvirt XML values regardless of input casing or padding.
int domainVal = 0, busVal = 0, slotVal = 0, funcVal = 0;

if (busAddress != null && !busAddress.isEmpty()) {
String[] parts = busAddress.split(":");
if (parts.length > 1) {
bus = "0x" + parts[0];
String[] slotFunctionParts = parts[1].split("\\.");
if (slotFunctionParts.length > 0) {
slot = "0x" + slotFunctionParts[0];
if (slotFunctionParts.length > 1) {
function = "0x" + slotFunctionParts[1].trim();
}
try {
String slotFunction;
if (parts.length == 3) {
domainVal = Integer.parseInt(parts[0], 16);
busVal = Integer.parseInt(parts[1], 16);
slotFunction = parts[2];
} else if (parts.length == 2) {
busVal = Integer.parseInt(parts[0], 16);
slotFunction = parts[1];
} else {
throw new IllegalArgumentException("Invalid PCI bus address format: '" + busAddress + "'");
}
String[] sf = slotFunction.split("\\.");
if (sf.length == 2) {
slotVal = Integer.parseInt(sf[0], 16);
funcVal = Integer.parseInt(sf[1].trim(), 16);
} else {
throw new IllegalArgumentException("Invalid PCI bus address format: '" + busAddress + "'");
}
} catch (NumberFormatException e) {
throw new IllegalArgumentException("Invalid PCI bus address format: '" + busAddress + "'", e);
}
}

String domain = String.format("0x%04x", domainVal);
String bus = String.format("0x%02x", busVal);
String slot = String.format("0x%02x", slotVal);
String function = String.format("0x%x", funcVal);

gpuBuilder.append(" <address domain='").append(domain).append("' bus='").append(bus).append("' slot='")
.append(slot).append("' function='").append(function.trim()).append("'/>\n");
.append(slot).append("' function='").append(function).append("'/>\n");
gpuBuilder.append(" </source>\n");
gpuBuilder.append("</hostdev>\n");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,145 @@ public void testGpuDef_withComplexPciAddress() {
assertTrue(gpuXml.contains("</hostdev>"));
}

@Test
public void testGpuDef_withFullPciAddressDomainZero() {
LibvirtGpuDef gpuDef = new LibvirtGpuDef();
VgpuTypesInfo pciGpuInfo = new VgpuTypesInfo(
GpuDevice.DeviceType.PCI,
"passthrough",
"passthrough",
"0000:00:02.0",
"10de",
"NVIDIA Corporation",
"1b38",
"Tesla T4"
);
gpuDef.defGpu(pciGpuInfo);

String gpuXml = gpuDef.toString();

assertTrue(gpuXml.contains("<address domain='0x0000' bus='0x00' slot='0x02' function='0x0'/>"));
}

@Test
public void testGpuDef_withFullPciAddressNonZeroDomain() {
LibvirtGpuDef gpuDef = new LibvirtGpuDef();
VgpuTypesInfo pciGpuInfo = new VgpuTypesInfo(
GpuDevice.DeviceType.PCI,
"passthrough",
"passthrough",
"0001:65:00.0",
"10de",
"NVIDIA Corporation",
"1b38",
"Tesla T4"
);
gpuDef.defGpu(pciGpuInfo);

String gpuXml = gpuDef.toString();

assertTrue(gpuXml.contains("<address domain='0x0001' bus='0x65' slot='0x00' function='0x0'/>"));
}

@Test
public void testGpuDef_withNvidiaStyleEightDigitDomain() {
// nvidia-smi reports PCI addresses with an 8-digit domain (e.g. "00000001:af:00.1").
// generatePciXml must normalize it to the canonical 4-digit form "0x0001".
LibvirtGpuDef gpuDef = new LibvirtGpuDef();
VgpuTypesInfo pciGpuInfo = new VgpuTypesInfo(
GpuDevice.DeviceType.PCI,
"passthrough",
"passthrough",
"00000001:af:00.1",
"10de",
"NVIDIA Corporation",
"1b38",
"Tesla T4"
);
gpuDef.defGpu(pciGpuInfo);

String gpuXml = gpuDef.toString();

assertTrue(gpuXml.contains("<address domain='0x0001' bus='0xaf' slot='0x00' function='0x1'/>"));
}

@Test
public void testGpuDef_withFullPciAddressVfNonZeroDomain() {
LibvirtGpuDef gpuDef = new LibvirtGpuDef();
VgpuTypesInfo vfGpuInfo = new VgpuTypesInfo(
GpuDevice.DeviceType.PCI,
"VF-Profile",
"VF-Profile",
"0002:81:00.3",
"10de",
"NVIDIA Corporation",
"1eb8",
"Tesla T4"
);
gpuDef.defGpu(vfGpuInfo);

String gpuXml = gpuDef.toString();

// Non-passthrough NVIDIA VFs should be unmanaged
assertTrue(gpuXml.contains("<hostdev mode='subsystem' type='pci' managed='no' display='off'>"));
assertTrue(gpuXml.contains("<address domain='0x0002' bus='0x81' slot='0x00' function='0x3'/>"));
}

@Test
public void testGpuDef_withLegacyShortBdfDefaultsDomainToZero() {
// Backward compatibility: short BDF with no domain segment must still
// produce a valid libvirt address with domain 0x0000.
LibvirtGpuDef gpuDef = new LibvirtGpuDef();
VgpuTypesInfo pciGpuInfo = new VgpuTypesInfo(
GpuDevice.DeviceType.PCI,
"passthrough",
"passthrough",
"af:00.0",
"10de",
"NVIDIA Corporation",
"1b38",
"Tesla T4"
);
gpuDef.defGpu(pciGpuInfo);

String gpuXml = gpuDef.toString();

assertTrue(gpuXml.contains("<address domain='0x0000' bus='0xaf' slot='0x00' function='0x0'/>"));
}

@Test
public void testGpuDef_withInvalidBusAddressThrows() {
String[] invalidAddresses = {
"notahex:00.0", // non-hex bus
"gg:00:02.0", // non-hex domain
"00:02:03:04", // too many colon-separated parts
"00", // missing slot/function
"00:02", // missing function (no dot)
"00:02.0.1", // extra dot in ss.f
};
for (String addr : invalidAddresses) {
LibvirtGpuDef gpuDef = new LibvirtGpuDef();
VgpuTypesInfo info = new VgpuTypesInfo(
GpuDevice.DeviceType.PCI,
"passthrough",
"passthrough",
addr,
"10de",
"NVIDIA Corporation",
"1b38",
"Tesla T4"
);
gpuDef.defGpu(info);
try {
String ignored = gpuDef.toString();
fail("Expected IllegalArgumentException for address: " + addr + " but got: " + ignored);
} catch (IllegalArgumentException e) {
assertTrue("Exception message should contain the bad address",
e.getMessage().contains(addr));
}
}
}

@Test
public void testGpuDef_withNullDeviceType() {
LibvirtGpuDef gpuDef = new LibvirtGpuDef();
Expand Down
Loading
Loading