|
| 1 | +use bevy::prelude::*; |
| 2 | +use bevy::render::RenderApp; |
| 3 | +use bevy::render::render_resource::{Texture, TextureFormat}; |
| 4 | +use bevy::render::renderer::RenderDevice; |
| 5 | +use bevy_cuda::{CudaBuffer, CudaContext}; |
| 6 | +use processing_core::app_mut; |
| 7 | +use processing_core::error::{ProcessingError, Result}; |
| 8 | +use processing_render::graphics::view_target; |
| 9 | +use processing_render::image::{Image, gpu_image, pixel_size}; |
| 10 | + |
| 11 | +#[derive(Component)] |
| 12 | +pub struct CudaImageBuffer { |
| 13 | + pub buffer: CudaBuffer, |
| 14 | + pub width: u32, |
| 15 | + pub height: u32, |
| 16 | + pub texture_format: TextureFormat, |
| 17 | +} |
| 18 | + |
| 19 | +pub struct CudaPlugin; |
| 20 | + |
| 21 | +impl Plugin for CudaPlugin { |
| 22 | + fn build(&self, _app: &mut App) {} |
| 23 | + |
| 24 | + fn finish(&self, app: &mut App) { |
| 25 | + let render_app = app.sub_app(RenderApp); |
| 26 | + let render_device = render_app.world().resource::<RenderDevice>(); |
| 27 | + let wgpu_device = render_device.wgpu_device(); |
| 28 | + match CudaContext::new(wgpu_device, 0) { |
| 29 | + Ok(ctx) => { |
| 30 | + app.insert_resource(ctx); |
| 31 | + } |
| 32 | + Err(e) => { |
| 33 | + warn!("CUDA not available, GPU interop disabled: {e}"); |
| 34 | + } |
| 35 | + } |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +fn cuda_ctx(world: &World) -> Result<&CudaContext> { |
| 40 | + world |
| 41 | + .get_resource::<CudaContext>() |
| 42 | + .ok_or(ProcessingError::CudaError("CUDA not available".into())) |
| 43 | +} |
| 44 | + |
| 45 | +fn resolve_texture(app: &mut App, entity: Entity) -> Result<(Texture, TextureFormat, u32, u32)> { |
| 46 | + if app.world().get::<Image>(entity).is_some() { |
| 47 | + let texture = gpu_image(app, entity)?.texture.clone(); |
| 48 | + let p_image = app.world().get::<Image>(entity).unwrap(); |
| 49 | + return Ok((texture, p_image.texture_format, p_image.size.width, p_image.size.height)); |
| 50 | + } |
| 51 | + if let Ok(vt) = view_target(app, entity) { |
| 52 | + let texture = vt.main_texture().clone(); |
| 53 | + let fmt = vt.main_texture_format(); |
| 54 | + let size = texture.size(); |
| 55 | + return Ok((texture, fmt, size.width, size.height)); |
| 56 | + } |
| 57 | + Err(ProcessingError::ImageNotFound) |
| 58 | +} |
| 59 | + |
| 60 | +pub fn cuda_export(entity: Entity) -> Result<()> { |
| 61 | + app_mut(|app| { |
| 62 | + let (texture, texture_format, width, height) = resolve_texture(app, entity)?; |
| 63 | + |
| 64 | + let px_size = pixel_size(texture_format)?; |
| 65 | + let buffer_size = (width as u64) * (height as u64) * (px_size as u64); |
| 66 | + |
| 67 | + let existing = app.world().get::<CudaImageBuffer>(entity); |
| 68 | + let needs_alloc = existing.is_none_or(|buf| buf.buffer.size() != buffer_size); |
| 69 | + |
| 70 | + if needs_alloc { |
| 71 | + let cuda_ctx = cuda_ctx(app.world())?; |
| 72 | + let buffer = cuda_ctx |
| 73 | + .create_buffer(buffer_size) |
| 74 | + .map_err(|e| ProcessingError::CudaError(format!("Buffer creation failed: {e}")))?; |
| 75 | + app.world_mut().entity_mut(entity).insert(CudaImageBuffer { |
| 76 | + buffer, |
| 77 | + width, |
| 78 | + height, |
| 79 | + texture_format, |
| 80 | + }); |
| 81 | + } |
| 82 | + |
| 83 | + let world = app.world(); |
| 84 | + let cuda_buf = world.get::<CudaImageBuffer>(entity).unwrap(); |
| 85 | + let cuda_ctx = cuda_ctx(world)?; |
| 86 | + |
| 87 | + cuda_ctx |
| 88 | + .copy_texture_to_buffer(&texture, &cuda_buf.buffer, width, height, texture_format) |
| 89 | + .map_err(|e| { |
| 90 | + ProcessingError::CudaError(format!("Texture-to-buffer copy failed: {e}")) |
| 91 | + })?; |
| 92 | + |
| 93 | + Ok(()) |
| 94 | + }) |
| 95 | +} |
| 96 | + |
| 97 | +pub fn cuda_import(entity: Entity, src_device_ptr: u64, byte_size: u64) -> Result<()> { |
| 98 | + app_mut(|app| { |
| 99 | + let (texture, texture_format, width, height) = resolve_texture(app, entity)?; |
| 100 | + |
| 101 | + let existing = app.world().get::<CudaImageBuffer>(entity); |
| 102 | + let needs_alloc = existing.is_none_or(|buf| buf.buffer.size() != byte_size); |
| 103 | + |
| 104 | + if needs_alloc { |
| 105 | + let cuda_ctx = cuda_ctx(app.world())?; |
| 106 | + let buffer = cuda_ctx |
| 107 | + .create_buffer(byte_size) |
| 108 | + .map_err(|e| ProcessingError::CudaError(format!("Buffer creation failed: {e}")))?; |
| 109 | + app.world_mut().entity_mut(entity).insert(CudaImageBuffer { |
| 110 | + buffer, |
| 111 | + width, |
| 112 | + height, |
| 113 | + texture_format, |
| 114 | + }); |
| 115 | + } |
| 116 | + |
| 117 | + let world = app.world(); |
| 118 | + let cuda_buf = world.get::<CudaImageBuffer>(entity).unwrap(); |
| 119 | + let cuda_ctx = cuda_ctx(world)?; |
| 120 | + |
| 121 | + // wait for work (i.e. python) to be done with the buffer before we read from it |
| 122 | + cuda_ctx |
| 123 | + .synchronize() |
| 124 | + .map_err(|e| ProcessingError::CudaError(format!("synchronize failed: {e}")))?; |
| 125 | + |
| 126 | + cuda_buf |
| 127 | + .buffer |
| 128 | + .copy_from_device_ptr(src_device_ptr, byte_size) |
| 129 | + .map_err(|e| ProcessingError::CudaError(format!("memcpy_dtod failed: {e}")))?; |
| 130 | + |
| 131 | + cuda_ctx |
| 132 | + .copy_buffer_to_texture(&cuda_buf.buffer, &texture, width, height, texture_format) |
| 133 | + .map_err(|e| { |
| 134 | + ProcessingError::CudaError(format!("Buffer-to-texture copy failed: {e}")) |
| 135 | + })?; |
| 136 | + |
| 137 | + Ok(()) |
| 138 | + }) |
| 139 | +} |
| 140 | + |
| 141 | +pub fn cuda_write_back(entity: Entity) -> Result<()> { |
| 142 | + app_mut(|app| { |
| 143 | + let (texture, _, _, _) = resolve_texture(app, entity)?; |
| 144 | + |
| 145 | + let cuda_buf = app |
| 146 | + .world() |
| 147 | + .get::<CudaImageBuffer>(entity) |
| 148 | + .ok_or(ProcessingError::ImageNotFound)?; |
| 149 | + |
| 150 | + let cuda_ctx = cuda_ctx(app.world())?; |
| 151 | + |
| 152 | + cuda_ctx |
| 153 | + .copy_buffer_to_texture( |
| 154 | + &cuda_buf.buffer, |
| 155 | + &texture, |
| 156 | + cuda_buf.width, |
| 157 | + cuda_buf.height, |
| 158 | + cuda_buf.texture_format, |
| 159 | + ) |
| 160 | + .map_err(|e| { |
| 161 | + ProcessingError::CudaError(format!("Buffer-to-texture copy failed: {e}")) |
| 162 | + })?; |
| 163 | + |
| 164 | + Ok(()) |
| 165 | + }) |
| 166 | +} |
| 167 | + |
| 168 | +pub struct CudaBufferInfo { |
| 169 | + pub device_ptr: u64, |
| 170 | + pub width: u32, |
| 171 | + pub height: u32, |
| 172 | + pub texture_format: TextureFormat, |
| 173 | +} |
| 174 | + |
| 175 | +pub fn cuda_buffer(entity: Entity) -> Result<CudaBufferInfo> { |
| 176 | + app_mut(|app| { |
| 177 | + let cuda_buf = app |
| 178 | + .world() |
| 179 | + .get::<CudaImageBuffer>(entity) |
| 180 | + .ok_or(ProcessingError::ImageNotFound)?; |
| 181 | + Ok(CudaBufferInfo { |
| 182 | + device_ptr: cuda_buf.buffer.device_ptr(), |
| 183 | + width: cuda_buf.width, |
| 184 | + height: cuda_buf.height, |
| 185 | + texture_format: cuda_buf.texture_format, |
| 186 | + }) |
| 187 | + }) |
| 188 | +} |
| 189 | + |
| 190 | +pub fn typestr_for_format(format: TextureFormat) -> Result<&'static str> { |
| 191 | + match format { |
| 192 | + TextureFormat::Rgba8Unorm | TextureFormat::Rgba8UnormSrgb => Ok("|u1"), |
| 193 | + TextureFormat::Rgba16Float => Ok("<f2"), |
| 194 | + TextureFormat::Rgba32Float => Ok("<f4"), |
| 195 | + _ => Err(ProcessingError::UnsupportedTextureFormat), |
| 196 | + } |
| 197 | +} |
| 198 | + |
| 199 | +pub fn elem_size_for_typestr(typestr: &str) -> Result<usize> { |
| 200 | + match typestr { |
| 201 | + "|u1" => Ok(1), |
| 202 | + "<f2" => Ok(2), |
| 203 | + "<f4" => Ok(4), |
| 204 | + _ => Err(ProcessingError::CudaError(format!( |
| 205 | + "unsupported typestr: {typestr}" |
| 206 | + ))), |
| 207 | + } |
| 208 | +} |
0 commit comments