Skip to content

Commit 7d9b189

Browse files
committed
Tested on an NVIDIA RTX 4090 GPU with CUDA 12.x
1 parent b2e34d0 commit 7d9b189

File tree

7 files changed

+543
-497
lines changed

7 files changed

+543
-497
lines changed

mlir/cuda-tile/.gitignore

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
11
*.ptx
22
*.cubin
33
*.fatbin
4+
*.bc
5+
*.ll
6+
*.o
7+
*.s
8+
*.so
9+
*.dylib
10+
*.a
11+
*.dll
12+
*.obj
13+
*.exe
14+
*.log
15+
*.cache
16+
*.tmp
17+
*.bin
18+
*.out

mlir/cuda-tile/README.md

Lines changed: 270 additions & 464 deletions
Large diffs are not rendered by default.

mlir/cuda-tile/Toy/cuda_wrapper/cuda_shim.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ cuda_shim_load_module_from_file(uint64_t file_path_ptr,
330330
uint64_t /*file_path_nbytes*/) {
331331
auto file_path_cstr =
332332
reinterpret_cast<const char *>(asHostCPtr(file_path_ptr));
333-
// fprintf(stdout, "%s", file_path_cstr);
333+
debug_print("Loading CUDA module from file: %s\n", file_path_cstr);
334334
CUmodule module = nullptr;
335335
ScopedContext scopedContext;
336336
CUDA_REPORT_IF_ERROR(cuModuleLoad(&module, file_path_cstr));
@@ -519,7 +519,7 @@ extern "C" void cuda_shim_ctx_synchronize(void) { mgpuCtxSynchronize(); }
519519

520520
// only for debugging
521521
extern "C" void cuda_debug_dump_float(uint64_t dptr, int n) {
522-
auto *p = reinterpret_cast<const float*>(static_cast<uintptr_t>(dptr));
522+
auto *p = reinterpret_cast<const float *>(static_cast<uintptr_t>(dptr));
523523
for (uint32_t i = 0; i < n; ++i) {
524524
fprintf(stderr, "i=%u v=%f\n", i, p[i]);
525525
}

mlir/cuda-tile/Toy/include/cuda_shim/CudaShimBuilder.hpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,74 @@ inline unsigned long getNbytes(mlir::Type tensorType) {
296296
ranked_tensor_type.getElementTypeBitWidth(),
297297
8);
298298
}
299+
300+
extern "C" {
301+
// Load module from PTX or CUBIN image in memory.
302+
// Driver API supports cuModuleLoadDataEx for both PTX and cubin (it
303+
// auto-detects).
304+
uint64_t cuda_shim_load_module_from_image(uint64_t image_ptr,
305+
uint64_t image_nbytes);
306+
uint64_t cuda_shim_load_module_jit_from_image(uint64_t image_ptr,
307+
uint64_t image_nbytes,
308+
int opt_level);
309+
310+
uint64_t cuda_shim_load_module_from_file(uint64_t file_path_ptr,
311+
uint64_t /*file_path_nbytes*/);
312+
313+
void cuda_shim_unload_module(uint64_t module_handle);
314+
315+
uint64_t cuda_shim_malloc(uint64_t nbytes, uint64_t stream,
316+
bool is_host_shared);
317+
318+
void cuda_shim_free(uint64_t dptr, uint64_t stream);
319+
320+
void cuda_shim_memset32(uint64_t dptr, uint32_t value, uint64_t count_dwords,
321+
uint64_t stream);
322+
void cuda_shim_memset16(uint64_t dptr, uint32_t value, uint64_t count_dwords,
323+
uint64_t stream);
324+
325+
uint64_t cuda_shim_stream_create(void);
326+
327+
void cuda_shim_stream_destroy(uint64_t stream);
328+
329+
void cuda_shim_stream_synchronize(uint64_t stream);
330+
331+
uint64_t cuda_shim_event_create(void);
332+
333+
void cuda_shim_event_destroy(uint64_t ev);
334+
335+
void cuda_shim_event_record(uint64_t ev, uint64_t stream);
336+
337+
void cuda_shim_event_synchronize(uint64_t ev);
338+
339+
void cuda_shim_stream_wait_event(uint64_t stream, uint64_t ev);
340+
341+
// ----------------------------- Memcpy (raw ABI) --------------------------
342+
// Host pointers are passed as uint64_t. This is the key point of design 2A (raw-ABI pointer passing).
343+
344+
void cuda_shim_memcpy_h2d(uint64_t dst_dptr, uint64_t src_hptr,
345+
uint64_t nbytes);
346+
347+
void cuda_shim_memcpy_d2h(uint64_t dst_hptr, uint64_t src_dptr,
348+
uint64_t nbytes);
349+
350+
void cuda_shim_launch_packed(uint64_t module_handle, uint64_t kernel_name_ptr,
351+
uint32_t gridX, uint32_t gridY, uint32_t gridZ,
352+
uint32_t blockX, uint32_t blockY, uint32_t blockZ,
353+
uint32_t sharedMemBytes, uint64_t stream,
354+
uint64_t arg_data_ptr, uint64_t arg_sizes_ptr,
355+
uint32_t num_args);
356+
357+
// Convenience: 1D launch, shared=0, stream optional
358+
void cuda_shim_launch_block_packed(uint64_t module_handle,
359+
uint64_t kernel_name_ptr, uint32_t blockX,
360+
uint32_t blockY, uint32_t blockZ,
361+
uint64_t stream, uint64_t arg_data_ptr,
362+
uint64_t arg_sizes_ptr, uint32_t num_args);
363+
364+
// Optional: global sync (avoid in async pipeline; prefer event/stream sync)
365+
void cuda_shim_ctx_synchronize(void);
366+
367+
// only for debugging
368+
void cuda_debug_dump_float(uint64_t dptr, int n);
369+
}

mlir/cuda-tile/Toy/include/toy/Passes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ std::unique_ptr<mlir::Pass> createGpuOutlinePass(std::string grid = "1,1,1");
3434

3535
std::unique_ptr<mlir::Pass> createCudaTileLoweringPass();
3636

37-
std::unique_ptr<mlir::Pass>
38-
createEmbedCudaTileBinaryPass(std::string tileirasExe = "tileiras",
39-
std::string gpuName = "sm_120");
37+
std::unique_ptr<mlir::Pass> createEmbedCudaTileBinaryPass(
38+
std::string tileirasExe = "tileiras", std::string gpuName = "sm_120",
39+
std::string cubinOrPtxPath = "", bool useCache = true);
4040

4141
} // namespace toy
4242
} // namespace mlir

mlir/cuda-tile/Toy/mlir/EmitCudaTile.cpp

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
#include "toy/Dialect.h"
88
#include "llvm/ADT/SmallVector.h"
99
#include "llvm/ADT/StringRef.h"
10+
#include "llvm/Support/DebugLog.h"
1011
#include "llvm/Support/FileSystem.h"
1112
#include "llvm/Support/MemoryBuffer.h"
1213
#include "llvm/Support/Program.h"
1314
#include "llvm/Support/raw_ostream.h"
15+
#include <string>
1416
#include <system_error>
1517

1618
using namespace llvm;
@@ -84,9 +86,13 @@ struct EmbedCudaTileBinaryPass
8486

8587
std::string tileirasExe;
8688
std::string gpuName;
89+
std::string cubinOrPtxPath;
90+
bool useCache;
8791

88-
EmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName)
89-
: tileirasExe(std::move(tileirasExe)), gpuName(std::move(gpuName)) {}
92+
EmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName,
93+
std::string cubinOrPtxPath, bool useCache)
94+
: tileirasExe(std::move(tileirasExe)), gpuName(std::move(gpuName)),
95+
cubinOrPtxPath(std::move(cubinOrPtxPath)), useCache(useCache) {}
9096

9197
void runOnOperation() override {
9298
ModuleOp top = getOperation();
@@ -126,13 +132,38 @@ struct EmbedCudaTileBinaryPass
126132
return;
127133
}
128134

129-
if (std::error_code ec =
130-
createTemporaryFile(cudaBinPath, "cuda_tile", "bin")) {
131-
op->emitError() << "failed to create temp out bin: " << ec.message();
132-
signalPassFailure();
135+
if (cubinOrPtxPath.empty()) {
136+
if (std::error_code ec =
137+
createTemporaryFile(cudaBinPath, "cuda_tile", "bin")) {
138+
op->emitError() << "failed to create temp out bin: " << ec.message();
139+
signalPassFailure();
140+
return;
141+
}
142+
} else {
143+
if (!useCache) {
144+
if (llvm::sys::fs::exists(cubinOrPtxPath)) {
145+
op->emitWarning() << "cuda binary file exists " << cubinOrPtxPath
146+
<< ", tileiras will overwrite it.";
147+
std::error_code ec = llvm::sys::fs::remove(cubinOrPtxPath);
148+
if (ec) {
149+
op->emitError() << "failed to remove existing cuda binary file: "
150+
<< ec.message();
151+
signalPassFailure();
152+
return;
153+
}
154+
}
155+
}
156+
cudaBinPath = cubinOrPtxPath;
157+
}
158+
159+
if (useCache && llvm::sys::fs::exists(cudaBinPath)) {
160+
LDBG() << "cuda binary file exists and will be reused: " << cudaBinPath
161+
<< "\n";
133162
return;
134163
}
135164

165+
// ! [FIXME]: comment out the following code before release; it is only
166+
// for testing.
136167
if (failed(writeFileBytes(inPath, tilebcBytes))) {
137168
op->emitError() << "failed to write temp tilebc";
138169
signalPassFailure();
@@ -145,6 +176,8 @@ struct EmbedCudaTileBinaryPass
145176
}
146177
});
147178

179+
LDBG() << "cuda binary path: " << cudaBinPath << "\n";
180+
148181
top->walk([&](toy::LaunchGpuOp launchOp) {
149182
// ---- Step D: read cuda binary bytes ----
150183
auto binBytesOrErr = readFileBytes(cudaBinPath);
@@ -189,8 +222,10 @@ struct EmbedCudaTileBinaryPass
189222
namespace mlir::toy {
190223

191224
std::unique_ptr<mlir::Pass>
192-
createEmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName) {
193-
return std::make_unique<EmbedCudaTileBinaryPass>(tileirasExe, gpuName);
225+
createEmbedCudaTileBinaryPass(std::string tileirasExe, std::string gpuName,
226+
std::string cubinOrPtxPath, bool useCache) {
227+
return std::make_unique<EmbedCudaTileBinaryPass>(tileirasExe, gpuName,
228+
cubinOrPtxPath, useCache);
194229
};
195230

196231
}; // namespace mlir::toy

0 commit comments

Comments
 (0)