From 473e48114a7a05f69332e18fdea65db496f05489 Mon Sep 17 00:00:00 2001 From: chjchjchjchjchj Date: Wed, 20 Nov 2024 14:09:09 +0800 Subject: [PATCH] Fix: lang_adaptor --- models/rdt_runner.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/models/rdt_runner.py b/models/rdt_runner.py index 3c452e4..338b00e 100644 --- a/models/rdt_runner.py +++ b/models/rdt_runner.py @@ -22,6 +22,9 @@ def __init__(self, *, action_dim, pred_horizon, config, max_lang_cond_len, img_cond_len, lang_pos_embed_config=None, img_pos_embed_config=None, dtype=torch.bfloat16): super(RDTRunner, self).__init__() + + self.dtype = dtype + # Create diffusion model hidden_size = config['rdt']['hidden_size'] self.model = RDT( @@ -41,18 +44,21 @@ def __init__(self, *, action_dim, pred_horizon, config, self.lang_adaptor = self.build_condition_adapter( config['lang_adaptor'], in_features=lang_token_dim, - out_features=hidden_size + out_features=hidden_size, + dtype=self.dtype ) self.img_adaptor = self.build_condition_adapter( config['img_adaptor'], in_features=img_token_dim, - out_features=hidden_size + out_features=hidden_size, + dtype=self.dtype ) # A `state` refers to an action or a proprioception vector self.state_adaptor = self.build_condition_adapter( config['state_adaptor'], in_features=state_token_dim * 2, # state + state mask (indicator) - out_features=hidden_size + out_features=hidden_size, + dtype=self.dtype ) # Create the noise scheduler @@ -83,18 +89,18 @@ def __init__(self, *, action_dim, pred_horizon, config, [p.numel() for p in self.state_adaptor.parameters()])) def build_condition_adapter( - self, projector_type, in_features, out_features): + self, projector_type, in_features, out_features, dtype): projector = None if projector_type == 'linear': - projector = nn.Linear(in_features, out_features) + projector = nn.Linear(in_features, out_features, dtype=dtype) else: mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) if mlp_gelu_match: mlp_depth = int(mlp_gelu_match.group(1)) - modules = [nn.Linear(in_features, out_features)] + modules = [nn.Linear(in_features, out_features, dtype=dtype)] for _ in range(1, mlp_depth): modules.append(nn.GELU(approximate="tanh")) - modules.append(nn.Linear(out_features, out_features)) + modules.append(nn.Linear(out_features, out_features, dtype=dtype)) projector = nn.Sequential(*modules) if projector is None: