万字长文 - 全网最全Pi0代码详解

Xbot具身知识库 2025-10-31 17:32
万字长文 - 全网最全Pi0代码详解图1

源码来自@董子斌 https://github.com/ZibinDong/openpi_pytorch

从demo理解输入和输出

# 输入
observation = {
    "image": {
        "base_0_rgb": torch.randint(
            0256, (13224224), dtype=torch.uint8, device=device
        ),
        # "left_wrist_0_rgb": ...,   Suppose we don't have this view
        # "right_wrist_0_rgb": ...,  Suppose we don't have this view
    },
    "state": torch.randn(18, device=device) * 0.2,
    "prompt": ["do something"],
}

# 运行模型和action输出
action = policy.select_action(observation)[0, :, :7]

算法输入

观测信息observation包含三部分:图像、机器人状态向量、语言指令。

图像信息(observation["image"]

  • 支持三路相机输入

    可以只输入部分,缺失的会自动用掩码补齐

    • base_0_rgb(主相机)
    • left_wrist_0_rgb(左侧腕部相机)
    • right_wrist_0_rgb (右侧腕部相机)
  • 每张输入图片的格式为uint8、范围 [0,255],形状 (*b, 3, H, W)

机器人状态 observation["state"]

  • float32,形状 (*b, s),表示机器人的本体状态(例如关节角),最多32维度

语言指令 observation["prompt"]

算法输出

输出格式和维度

  • 针对返回值actions,其形状为(*b, config.n_action_steps, config.max_action_dim)

  • 在默认情况下,

    • n_action_steps = 50(输出一个“动作块/chunk”,又叫“动作地平线 ”)
    • max_action_dim = 32(短动作向量会右侧 pad 到 32)。

针对代码

action = policy.select_action(observation)[0, :, :7]

表示取出第 1 个 batch全部 50 个时间步,并截取前 7 个动作维度

动作维度和机器人怎么对应?

PI0 是“跨机型/跨体态”的统一策略,训练时把不同机器人统一到同一上限维度,并对不足的维度做零填充(state/action 都一样)。例如论文里总结了典型平台的“配置/动作维度”:

这些都统一 pad 到数据集最大维度(论文示例最大 18,代码实现允许到 32 作为上限)。使用时你按自己的机器人有效 DoF 截取前action_dim即可(比如 UR5e 取前 7 维)。

为什么代码里面是:7

因为示例里假设是 UR5e(6 关节 + 夹爪 = 7),于是 :7 正好取有效部分;后面的列是 padding。

    • 单臂:UR5e(7 维:6 关节 + 1 夹爪);Franka(8 维:7 关节 + 1 夹爪)
    • 双臂:常见是 14 维(两只 6DoF 手臂 + 两个夹爪)
    • 移动操作臂:在 14 维基础上再加底盘控制维度(非全向底盘 +2,自由底盘 +3),所以 16 或 17 维

输出action的物理含义

  • 关节(Arm Joints):

在 PI0Config 里有配置信息:

默认输出的关节部分就是 绝对值目标(absolute joint positions),如果在 ALOHA 适配场景中把 use_delta_joint_actions_aloha=True,则关节动作会变成相对于当前状态的增量 Δq。夹爪保持绝对值。

    • adapt_to_pi_aloha: bool = False
    • use_delta_joint_actions_aloha: bool = False
  • 夹爪(Gripper):

源码里明确注释:

“Gripper dimensions will remain in absolute values.”

所以无论 use_delta_joint_actions_aloha 是否开启,夹爪维度都是 绝对值

  • 移动底盘(Mobile Base) :

移动底盘采用 twist 指令,即:

twist 指令的坐标系为车身坐标系(body frame)

“These actions are in the robot’s local body frame.”

    • 如果机器人有非全向底盘,则动作维度中包含  → 平移速度、角速度

    • 如果是全向底盘,则包含 

action输出到实际控制

  • 控制频率:

论文明确:PI0 能以最高 50 Hz 控制执行(例如执行高灵巧操作)。也有平台以 20 Hz 运行(UR5e/Franka)

  • 推理 / 执行的节奏(关键点)

开环地顺序执行动作块(作者尝试过块间“集成/平滑”策略,会变差,遂放弃)!

论文在附录 D 给了实际节奏:

    • 20 Hz 机器人:每 0.8 s 重新前向一次(也就是先执行 16 步,再重算下一个动作块)
    • 50 Hz 机器人:每 0.5 s 前向一次(执行 25 步再重算)。
  1. 一次前向会生成一个50步的“动作块”(n_action_steps = 50
  2. 按给定的控制频率逐步执行这些动作(例如 50 Hz 就每 20 ms 执行下一步)

多模态数据的准备

policy.select_action(observation)的具体实现如下:

classPI0Policy(PreTrainedPolicy):  

 @torch.no_grad
    defselect_action(
        self, observation: dict[str, Tensor], noise: Tensor | None = None
    )
:

        """
        Observation: {
            "image": {
                "base_0_rgb": (*b, c, h, w),  # uint8 [0, 255]
                ...
            },
            "state": float32 [*b, s],
            "prompt": List[str],

            "lang_tokens": float32 [*b, l],
            "lang_masks": float32 [*b, l],
        }
        either provide `prompt` or (`lang_tokens`, `lang_masks`).
        """

        self.eval()

        images, img_masks = self.prepare_images(observation)
        state = self.prepare_state(observation)
        lang_tokens, lang_masks = self.prepare_language(observation)
        actions = self.model.sample_actions(
            images, img_masks, lang_tokens, lang_masks, state, noise=noise
        )
        return actions

针对图像的预处理prepare_images

defprepare_images(self, observation: dict[str, Tensor]):
    """Normalize, resize, and pad images and stack them into a tensor.
    
    Args:
        observation (dict[str, Tensor])

    Returns:
        images (torch.Tensor): (*b, n, c, h, w) images in range [-1.0, 1.0]
        img_masks (torch.Tensor): (*b, n) masks for images, True if image is present, False if missing
    """
    dtype = observation["state"].dtype  # 获取状态的 dtype,通常是 float32    bsize = observation["state"].shape[0]  # 获取批大小(batch size)    images, img_masks = [], []  # 用于存储图像和掩码    # 确定哪些图像存在    present_img_keys = [key for key in IMAGE_KEYS if key in observation["image"]]    missing_img_keys = [key for key in IMAGE_KEYS if key notin present_img_keys]    # 处理存在的图像    for key in present_img_keys:        img = observation["image"][key]  # 获取图像        img = img.to(dtype) / 127.5 - 1.0# 归一化到 [-1, 1] 的范围        img = resize_with_pad(  # 调整图像大小并填充            img, *self.config.resize_imgs_with_padding, pad_value=-1.0        )        images.append(img)  # 存储处理过的图像        img_masks.append(torch.ones((bsize,), dtype=torch.bool, device=img.device))  # 对应掩码为 True    # 处理缺失的图像(填充)    for key in missing_img_keys:        img = torch.full_like(img, fill_value=-1.0)  # 填充 -1.0        images.append(img)  # 存储填充图像        img_masks.append(torch.zeros((bsize,), dtype=torch.bool, device=img.device))  # 对应掩码为 False    # 堆叠图像和掩码    images = torch.stack(images, dim=1)  # (*b, n, c, h, w)    img_masks = torch.stack(img_masks, dim=1)  # (*b, n)    return images, img_masks

Step 1:获取图像键、batch 信息与容器初始化

dtype = observation["state"].dtype      # 一般是 torch.float32
bsize = observation["state"].shape[0]   # batch_size (b)
images, img_masks = [], []              # 收集每个相机的图像与掩码

# 支持的相机顺序(决定输出的相机维度顺序)
IMAGE_KEYS = ("base_0_rgb""left_wrist_0_rgb""right_wrist_0_rgb")

present_img_keys = [key for key in IMAGE_KEYS if key in observation["image"]]
missing_img_keys = [key for key in IMAGE_KEYS if key notin present_img_keys]
  • dtype:用 state 的 dtype,把图像转成相同精度,方便后续拼接/前向一致(通常 float32)。
  • bsizebatch 大小b,后面所有图像、掩码都会以 b 为第一维。
  • IMAGE_KEYS:支持的相机视角("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
  • present_img_keys实际存在的相机键列表。
  • missing_img_keys缺失的相机键列表(后面会用占位图和 False 掩码补齐)。

Step 2:图像预处理

# 处理存在的图像
for key in present_img_keys:    img = observation["image"][key]           # (b, 3, H, W),uint8,范围 [0, 255]    img = img.to(dtype) / 127.5 - 1.0         # (b, 3, H, W),float,范围 [-1, 1]    img = resize_with_pad(                    # 保持比例 resize 到 224,并用 -1.0 pad 到 (224, 224)        img, *self.config.resize_imgs_with_padding, pad_value=-1.0    )                                         # (b, 3, 224, 224)    images.append(img)                        # 先以“列表”形式收集 (b, 3, 224, 224)    img_masks.append(        torch.ones((bsize,), dtype=torch.bool, device=img.device)    )                                         # (b,) = 全 True(这一路相机存在)# 处理缺失的图像(填充)for key in missing_img_keys:    img = torch.full_like(img, fill_value=-1.0)  # 填充 -1.0    images.append(img)  # 存储填充图像    img_masks.append(torch.zeros((bsize,), dtype=torch.bool, device=img.device))  # 对应掩码为 False

  1. 归一化:把图像像素值从 [0, 255] 转换为 [-1, 1]
  2. resize 和 padding:将图像调整为指定的尺寸(在配置中为 (224, 224))并进行填充。填充是为了统一图像大小,保持输入张量的一致性。如果某些图像没有数据,预设会用值为 -1.0 的填充值填充。
  3. 暂存imagesimg_masks(到下一个Step进行讲解)

Step3:图像和掩码堆叠

# 堆叠图像和掩码
images = torch.stack(images, dim=1)  # (*b, n, c, h, w)
img_masks = torch.stack(img_masks, dim=1)  # (*b, n)

图像堆叠(images):

将所有图像堆叠成一个五维张量,形状是 (b, n, 3, 224, 224),其中:

  • b 是批大小(batch size),
  • n 是有效的图像数量(可能是 1, 2 或 3),
  • 3 是通道数(RGB图像),
  • 224 和 224 是图像的高度和宽度。

掩码堆叠(img_masks

对应的掩码张量 img_masks 也会堆叠,形状是 (b, n),表示每张图像是否有效

  • 如果图像存在,掩码的值是 True,表示图像有效。
  • 如果图像缺失,掩码的值是 False,表示图像缺失或无效。

这里给出当有效图像只有"base_0_rgb", "left_wrist_0_rgb"时的例子

  • present_img_keys["base_0_rgb", "left_wrist_0_rgb"]
  • missing_img_keys["right_wrist_0_rgb"]
  • 下图给出image_masks的堆叠情况
  • 针对images的堆叠,维度和下图类似,每个10的位置一张图片矩阵或由-1填充得到的矩阵
万字长文 - 全网最全Pi0代码详解图2

针对机器人状态的预处理prepare_state

defprepare_state(self, observation: dict[str, Tensor]):
    """Pad the state to the maximum state dimension.

    Args:
        observation (dict[str, Tensor])

    Returns:
        state (torch.Tensor): (*b, max_state_dim) padded state tensor
    """

    state = observation["state"]
    state = F.pad(state, (0, self.config.max_state_dim - state.shape[1]))
    return state

把 observation["state"]从原始维度(b, s)右侧补零到 (b, max_state_dim),使每个样本的状态向量长度一致,便于后续喂给固定输入宽度的线性层。

torch.nn.functional.pad(input, pad, mode='constant', value=0)

pad 参数是一个元组,从 最后一维开始,成对指定 (left, right)  的填充数量。

  • pad 是一个元组,从最后一维开始,成对指定 (pad_left, pad_right)
  • 如果张量是多维的,就继续往前一维一维写。例如 2D 就需要两个维度的 pad(即 4 个数),3D 需要 6 个数,以此类推。
  • 默认填充值是 0,可以通过 value=... 修改。

Example:

x = torch.tensor([1,2,3]) # pad=(l, r), 在左边补 l 个数, 在右边补 r 个数
F.pad(x, (2,1))      # → [0,0,1,2,3,0]
x = torch.arange(6).reshape(2,3)   # shape (2,3)
# [[0,1,2],
#  [3,4,5]]

y = F.pad(x, (1,22,1))  # pad=(w_left, w_right, h_left, h_right)
# w_left=1, w_right=2, h_left=2, h_right=1
# 输出形状 (2+2+1, 3+1+2) = (5,6)
x = torch.zeros(1,2,3)  # shape (C=1,H=2,W=3)
# pad=(w_left, w_right, h_left, h_right, c_left, c_right)
y = F.pad(x, (1,12,00,2)) 
# W 左右各补1 → W=5
# H 上补2下补0 → H=4
# C 前补0后补2 → C=3
# y.shape == (3,4,5)

针对任务指令的预处理prepare_language

defprepare_language(self, observation: dict[str, Tensor]):
    """
    返回:
        lang_tokens: (*b, l)  # int 型 token ids
        lang_masks : (*b, l)  # bool 掩码, True=该位置有有效 token, False=padding
    """
    lang_tokens = observation.get("lang_tokens"None)    lang_masks  = observation.get("lang_masks"None)    prompt      = observation.get("prompt"None)    # 必须二选一: prompt 或 (lang_tokens, lang_masks)    if prompt isNoneand (lang_tokens isNoneor lang_masks isNone):        raise ValueError(...)    device = observation["state"].device    if prompt isnotNoneand (lang_tokens isNoneor lang_masks isNone):        # 1) 规范化文本到 PaliGemma 期望格式        prompt = [p if p.startswith(""elsef"{p}"for p in prompt]        prompt = [p if p.endswith("\n"elsef"{p}\n"for p in prompt]        # 2) 调 tokenizer -> 得到 (b, L) 的 ids 与 mask        tokenized_prompt = self.language_tokenizer.__call__(            prompt,            padding="max_length",            padding_side="right",            max_length=self.config.tokenizer_max_length,            return_tensors="pt",        )        lang_tokens = tokenized_prompt["input_ids"].to(device=device)        lang_masks  = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)    else:        # 直接使用用户提供的 token/mask        lang_tokens = observation["lang_tokens"].to(device=device)        lang_masks  = observation["lang_masks"].to(device=device, dtype=torch.bool)    return lang_tokens, lang_masks

输入模式与选择逻辑

  • 可以传 原始文本prompt: List[str] ,也可以直接传 已分好词的lang_tokenslang_masks(两者必须成对)。

  • 批次大小 b 来自 observation["state"].shape[0]。无论是 prompt 还是 lang_tokens/lang_masks第一维都要等于b

使用 prompt 的处理流程(最常见)

Step1:规范化到 PaliGemma 前缀格式

prompt = [p if p.startswith("") else f"{p}" for p in prompt]
prompt = [p if p.endswith("\n") else f"{p}\n" for p in prompt]
  • 在开头加 ,在结尾加换行 \n(PaliGemma 以 \n 作为分隔  的约定)。
  • 如果原文本已经有  或已以 \n 结尾,就不会重复添加。

Step2:tokenizer

tokenized_prompt = self.language_tokenizer.__call__(
    prompt,
    padding="max_length",         # 不足 L 的右侧 pad 到 L
    padding_side="right",
    max_length=self.config.tokenizer_max_length,  # L,默认 48
    return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)        # 形状: (b, L), dtype=torch.long
lang_masks  = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)  # 形状: (b, L), True=非pad

矩阵维度

  • lang_tokens(b, L),整型 token id(通常是 torch.long

  • lang_masks : (b, L),布尔掩码;

    • True 表示该位置是有效 token(包括  与正文)
    • False 表示 padding。

padding:

  • 右侧补齐(padding_side="right"
  • 上面没有显式写 truncation=True。大多数 HF tokenizer 在 max_length+padding="max_length" 情况下会对超长序列进行截断(同时给出 warning)。

使用 lang_tokens/lang_masks 的处理流程

lang_tokens = observation["lang_tokens"].to(device=device)            # (b, L)
lang_masks  = observation["lang_masks"].to(device=device, dtype=torch.bool)  # (b, L)
  • 传入的tokens/masks 将被直接使用,不会再做 /\n 之类的改写,也不会再 padding。

  • 请务必保证:

  1. 形状第一维等于 batch 大小 b
  2. 第二维 L 不应超过 config.tokenizer_max_length(否则后续拼接位置可能不对齐);
  3. lang_tokens.dtype 应为 整型(通常 torch.long),因为后续会走 embedding;
  4. lang_masks.dtype 为 bool,True=非 pad,False=pad。

信息嵌入

把“图像 + 语言提示”当作前缀(prefix) ,把“状态/动作/时间步”等控制相关的 token 当作后缀(suffix)

前缀信息嵌入 embed_prefix

总览

   defembed_prefix(
        self, images, img_masks, lang_tokens, lang_masks
    )
 -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        """Embed images with SigLIP and language tokens with embedding layer to prepare
        for PaliGemma transformer processing.

        Args:
            images (torch.Tensor):    float (*b, n, c, h, w) images in range [-1.0, 1.0]
            img_masks (torch.Tensor):  bool (*b, n) masks for images
            lang_tokens (torch.Tensor): int (*b, l) language tokens
            lang_masks (torch.Tensor): bool (*b, l) masks for language tokens
        """
        bsize = images.shape[0]        device = images.device        dtype = images.dtype        # embed image        images = einops.rearrange(images, "b n c h w -> (b n) c h w")        img_emb = self.paligemma_with_expert.embed_image(images)        num_patch = img_emb.shape[1]        img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize)        img_emb = img_emb.to(dtype=dtype) * (img_emb.shape[-1] ** 0.5)        num_img_embs = img_emb.shape[1]        img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch)        # embed language        lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)        num_lang_embs = lang_emb.shape[1]        lang_emb = lang_emb.to(dtype=dtype) * np.sqrt(lang_emb.shape[-1])        # assemble embeddings        embs = torch.cat([img_emb, lang_emb], dim=1)        pad_masks = torch.cat([img_masks, lang_masks], dim=1)        # PaliGemma uses bidirectional attention for prefix tokens,        # so we set 1D `att_masks` to zeros.        # (see `make_att_2d_masks` to understand why zeros means bidirection)        att_masks = torch.zeros(            (bsize, num_img_embs + num_lang_embs), device=device, dtype=torch.bool        )        return embs, pad_masks, att_masks

embed_prefix(...) 的职责就是把前缀两种模态(图像、文本)各自编码成同一维度的 token 序列,然后沿着序列维拼起来,并产出两类 mask:

  • pad_masks:哪一些位置是真实 token、哪一些是 padding(缺失图像或被 tokenizer pad 的 token)。
  • att_masks(1D) :前缀 token 的注意力“类型”标记(这里全部设为 0),供后续 make_att_2d_masks 展开为真正的二维注意力矩阵时使用。对前缀,含义是“双向注意力”。

为什么前缀要双向?
PaliGemma 的前缀部分(图像 patch + 文本 token)用双向自注意力来做跨模态对齐/融合,这样所有前缀 token 能互相看见,利于把视觉上下文和语言提示融合后,再把“控制后缀”接进来驱动动作专家(Expert Gemma)。

Step1:图像编码(SigLIP 视觉编码器 → patch token)

images = einops.rearrange(images, "b n c h w -> (b n) c h w")
  • 输入 images 形状是 (*b, n, c, h, w)n 是视角数(top、left_wrist、right_wrist 等)。
  • 这里把 b 和 n 合起来,变成 (b*n, c, h, w),这样可以直接喂给视觉编码器做逐图前向。

img_emb = self.paligemma_with_expert.embed_image(images)

调 SigLIP/ViT-style 的视觉编码:输出形状通常是 ((b*n), L, D)

  • L:每张图切成多少个 patch,再加上可能的特殊 token(具体是否有 CLS 以实现为准;从后面逻辑看,它把所有长度叫 num_patch 用于 mask 展开)。对于 224×224、patch 16,L ≈ 14×14=196(若有 CLS 就 197)。
  • D:视觉嵌入维度(与 PaliGemma 主干的隐藏维一致或经投影一致,代码里稍后只用 shape[-1]

num_patch = img_emb.shape[1]
# 再把 (b*n) 还原成 b,并把 “n 张图的 L 个 patch”摊平成一段连续序列
img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize)

还原后 img_emb 形状是 (b, n*L, D):也就是把多个相机的所有 patch 串接成“一段前缀图像 token 序列”。


# 统一 dtype(比如到 fp16/bf16),并做 “√D 放缩”
img_emb = img_emb.to(dtype=dtype) * (img_emb.shape[-1] ** 0.5)
num_img_embs = img_emb.shape[1]
  • num_img_embs = n*L,就是图像前缀 token 的总长度。

为什么乘√D
Transformer 的经典做法(源自 Vaswani et al.)是把词嵌入乘以 √d_model,以便和位置编码的量级匹配、并在层归一化前维持稳定的方差尺度。此处把视觉 patch 嵌入也乘 √D,让“图像 token”和后面“文本 token”在数值尺度上对齐,避免某一模态数值太小/太大导致注意力分布偏置。


# 把每张图的存在性 mask,从 (b, n) 展开到 (b, n*L)
img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch)

  • 先前 prepare_images 里,如果某个视角缺失,会放一张全 -1 的占位图,同时 img_masks 里该视角是 False。

  • 这里展开后,每张图的 每个 patch 都继承该视角的 True/False。

  • 这个 mask 之后用于:

  1. 构造二维注意力 mask 时屏蔽掉“来自缺失视角的 patch token”;
  2. 计算 position_idsposition_ids = cumsum(pad_masks)-1),使 padding 的位置得到 -1(无效)。

为什么要扩维度?

einops.repeat(img_masks, "b n -> b (n l)", l=num_patch) 的作用就是 把每个相机的存在性掩码扩展到该相机的所有 patch token 上,保证和图像嵌入序列长度一致,从而正确屏蔽缺失视角的所有 patch。

原始形状:

  • img_masks 在进入 embed_prefix 前是 (b, n)

    • b = batch size
    • n = 视角数(top camera, left wrist, right wrist …)
      每个位置只表示“这一张图是否存在”。
  • images在进入 embed_prefix 前是 (b, n, c, h, w)

  • images经过视觉编码器,每张图不再是一个整体,而是被切成 num_patch 个 patch(token),每个token向量的长度是。所以图像嵌入 img_emb 的形状是 (b, n * num_patch, d)

扩展原因:

因为 img_emb 有 n * num_patch个 token,而 img_masks 只有 n 个标志位。如果我们保持 img_masks 不变((b, n)),就没法和 img_emb 对齐。

所以要把 img_masks广播/扩展,让它为每个 patch 也提供一个 mask 值:

img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch)

新形状是 (b, n * num_patch),正好能和 img_emb 的序列长度对齐。

举个例子:

假设:

  • batch size b=1
  • 只有 2 个相机 n=2
  • 每张图切成 num_patch=3

那么:

  • img_masks 原始:[[1, 0]] (第一个相机存在,第二个相机缺失)

  • 扩展后:[[1, 1, 1, 0, 0, 0]]

    • 前 3 个 patch 来自第一个相机 → 有效
    • 后 3 个 patch 来自第二个相机 → 全部 mask 掉

这样,后续 Transformer 在做注意力时,就不会“看见”缺失相机的 patch。

Step2:文本编码(Gemma 词嵌入层)

lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
num_lang_embs = lang_emb.shape[1]
lang_emb = lang_emb.to(dtype=dtype) * np.sqrt(lang_emb.shape[-1])
  • embed_language_tokens 一般是查词表 + 线性投影或直接查表,输出 (b, l, d)

img_emb 的最后一维度lang_emb 的最后一维度是一样的。因为这两个模态(图像 patch token 和语言 token)最后要 拼接在一起 形成一条连续的前缀序列,然后送进同一个 PaliGemma Transformer。Transformer 的输入必须在最后一维(hidden size)对齐,否则没法一起做注意力和线性变换。

  • 同样做了 乘 √D 的放缩,理由与图像相同:把语言 token 的数值尺度和视觉 token 对齐,利于后续把两段序列拼接、共享同一个 Transformer 主干(PaliGemma)。

拼接序列 + 组合 mask

embs = torch.cat([img_emb, lang_emb], dim=1)   # (b, 图像个数*每个图像的token数 + 语言token数, D)
pad_masks = torch.cat([img_masks, lang_masks], dim=1)  # (b, n*L + l)
  • 现在前缀序列是:“所有图像 patch token 在前语言 token 在后”。
  • 这一步确保两个模态统一在一条时间轴/位置轴上,方便后面一口气喂进 PaliGemma 的 Transformer。

生成前缀的 1D 注意力类型 mask(全 0 = 双向)

# 前缀使用双向注意力:把 1D att_masks 全部设为 0
# (make_att_2d_masks 里会把 0 解读为“bidirectional 前缀”,生成全互看的注意力矩阵)
att_masks = torch.zeros(
    (bsize, num_img_embs + num_lang_embs), device=device, dtype=torch.bool
)
return embs, pad_masks, att_masks
  • 这里的 att_masks不是“padding mask”,也不是最终的二维注意力 mask,而是一条 “每个 token 的注意力类型标记” 。
  • 在后续的代码中,make_att_2d_masks会生成2D注意力矩阵

尽管att_masks 是2D矩阵,但由于第一维是batch,那么,对于一个 batch,att_masks[i] 就是一行向量,长度等于序列长度 seq_len。即,从“单条序列”的角度看,它就是 1D 的向量。

后缀信息嵌入embed_suffix

defembed_suffix(self, state, noisy_actions, timestep):
    """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.

    Args:
        state (torch.Tensor):         float32 (*b, s) robot state
        noisy_actions (torch.Tensor): float32 (*b, n, m) noisy actions
        timestep (torch.Tensor):      float32 (*b,) timestep in [0, 1] range
    """
    bsize = state.shape[0]    device = state.device    dtype = state.dtype    # embed state    state_emb = self.state_proj(state)    # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]    time_emb = create_sinusoidal_pos_embedding(        timestep,        self.config.proj_width,        min_period=4e-3,        max_period=4.0,        device=device,    )    time_emb = time_emb.type(dtype=dtype)    # Fuse timestep + action information using an MLP    action_emb = self.action_in_proj(noisy_actions)    time_emb = einops.repeat(time_emb, "b d -> b n d", n=action_emb.shape[1])    action_time_emb = torch.cat([action_emb, time_emb], dim=-1)    action_time_emb = self.action_time_mlp_in(action_time_emb)    action_time_emb = F.silu(action_time_emb)  # swish == silu    action_time_emb = self.action_time_mlp_out(action_time_emb)    action_time_dim = action_time_emb.shape[1]    # Add to input tokens    embs = torch.cat([state_emb[:, None], action_time_emb], dim=1)    pad_masks = torch.ones(        (bsize, action_time_dim + 1), device=device, dtype=torch.bool    )    # Set attention masks for suffix tokens so that prefix tokens cannot attend to suffix tokens.    # And state token cannot attend action tokens.    # Action tokens use a bidirectional attention.    att_masks = torch.zeros(        (bsize, action_time_dim + 1), device=device, dtype=torch.bool    )    att_masks[:, :2] = True    return embs, pad_masks, att_masks

noisy_actions 的维度结构

noisy_actions.shape = (batch_size, n_action_steps, max_action_dim)

各维度含义

  • 第0维 (batch_size) :批次大小,通常为1或更多
  • 第1维 (n_action_steps) :动作序列长度,默认为50步
  • 第2维 (max_action_dim) :动作的最大维度,为了支持不同机器人而设置为18

两个时间概念

Flow Matching中的

这是生成过程的数学参数,不是物理时间,可以将其理解为去噪进度

  • t=0:完全的噪声数据
  • t=1:完全的真实动作序列
  • t=0.3:30%真实动作 + 70%噪声的混合

这个过程发生在模型推理时,不是机器人执行时

机器人的实际执行时间

  • 物理时刻: 执行动作
  • 物理时刻: 执行动作
  • 物理时刻2: 执行动作
  • ...
  • 物理时刻: 执行动作

这些是真实世界的50个不同时刻,与flow matching的t无关。

Step1:状态嵌入

state_emb = self.state_proj(state)
  • self.state_proj 是一个线性层:nn.Linear(self.config.max_state_dim, self.config.proj_width)
  • 将机器人状态从原始维度(如关节角度)映射到transformer的嵌入维度
  • 例如:7维关节角度(已经在prepare_state中扩充为max_state_dim维度) → 1024维嵌入向量

Step2:时间步嵌入

time_emb = create_sinusoidal_pos_embedding(
    timestep,     # flowmathing里面的时间
    self.config.proj_width,
    min_period=4e-3,
    max_period=4.0,
    device=device,
)
time_emb = time_emb.type(dtype=dtype)

调用 create_sinusoidal_pos_embedding 函数,完成时间步嵌入

defcreate_sinusoidal_pos_embedding(
    time: torch.tensor,
    dimension: int,
    min_period: float,
    max_period: float,
    device="cpu",
)
 -> Tensor:

    """Computes sine-cosine positional embedding vectors for scalar positions."""    if dimension % 2 != 0:        raise ValueError(f"dimension ({dimension}) must be divisible by 2")    if time.ndim != 1:        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")    fraction = torch.linspace(        0.01.0, dimension // 2, dtype=torch.float32, device=device    )    period = min_period * (max_period / min_period) ** fraction    # Compute the outer product    scaling_factor = 1.0 / period * 2 * math.pi    sin_input = scaling_factor[None, :] * time[:, None]    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)    return pos_emb

数学公式

公式推导
  • 批大小为 ,输出维度为 (要求  为偶数),令 

  • 设最小周期 ,最大周期 ,令 

  • 当  时,频率在对数尺度上线性分布:

    对 ,定义

    当  时,取唯一的周期

  • 对应的角频率为

  • 给定一个标量时间/位置 ,定义正弦-余弦位置编码为

  • 给定一组时间 

    这里的  指的是每个batch 的时间步值,也就是说没有关系

    则外积

    最终的嵌入矩阵

    其中“|”表示按列拼接。

    频率范围端点(当 ):

    且在对数尺度上线性均匀覆盖区间 

为什么这样设计
  1. 多尺度覆盖(log 均匀频率)

    令周期按指数(几何)级数分布,等价于频率在对数尺度上线性分布,可同时覆盖低频(大周期)与高频(小周期)模式,捕捉从缓慢变化到快速变化的多尺度信息。

  2. 使用的原因

    周期为  的正弦波应写为 。写成  可直接用角频率,单位一致并且语义直观: 就是“实际周期”。

  3. 成对的 sin 和 cos

    对任意位移 Δ,

    即平移对应于每个频率通道上的线性旋转变换,这使得模型更容易从绝对位置泛化到相对位移。

  4. 内积只与相对位置有关(对每个频率对)

    对同一频率通道,利用恒等式 

  5. 整体嵌入的内积成为多频率余弦核的和:

    只依赖差值 ,有利于捕捉相对位置信息。

  6. 维度必须为偶数

    每个频率占用两个维度(sin 与 cos),故  需为偶数,

公式和代码的逐行对应

1.维度与形状检查

if dimension % 2 != 0:
    raise ValueError(f"dimension ({dimension}) must be divisible by 2")

if time.ndim != 1:
    raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

  • 要求  为偶数,即
  • 期望  的形状是 ,即 

  1. 对数线性采样周期(或对数线性采样频率)
fraction = torch.linspace(
    0.01.0, dimension // 2, dtype=torch.float32, device=device
)
period = min_period * (max_period / min_period) ** fraction

3.角频率计算

scaling_factor = 1.0 / period * 2 * math.pi

  1. 外积(批量化)与逐元素非线性
sin_input = scaling_factor[None, :] * time[:, None]
索引形状与广播乘法
  • scaling_factor 原形状为 。加上 None(等价于 unsqueeze(0))后变成 ,即一行 K 列
  • time 原形状为 。用 [:, None](等价于 unsqueeze(-1))后变成 ,即 B 行一列。
  • 按 PyTorch 的广播规则: 与  相乘,会在维度为 1 的位置进行扩展,结果形状为 
  • 任意位置  的元素为,这正是外积  的元素定义

等效的其他写法

  • sin_input = time[:,None]∗scaling_factor[None,:]
  • sin_input=time[:,None] @ scaling_factor[None,:]
  • sin_input=torch.outer(time, scaling_factor)
torch.sin(sin_input), torch.cos(sin_input)
  • 对  的每个元素施加 sin、cos

  1. 拼接得到最终嵌入
pos_emb = torch.cat([sin(...), cos(...)], dim=1)

示例

  1. 选取

  2. 参数设定

  • 维度  → 频率通道数 

  • 周期按几何级数:

  • 角频率:

  • 取批量时间向量:

  • 计算过程

  • 先算外积相位矩阵

  • 对  逐元素取正弦与余弦,并按列拼接得到最终嵌入:

4.验证代码

import math
import torchfrom torch import Tensor# 你的函数(原样拷贝)defcreate_sinusoidal_pos_embedding(
    time: torch.tensor,
    dimension: int,
    min_period: float,
    max_period: float,
    device="cpu",
)
 -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions."""    if dimension % 2 != 0:        raise ValueError(f"dimension ({dimension}) must be divisible by 2")    if time.ndim != 1:        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")    fraction = torch.linspace(        0.01.0, dimension // 2, dtype=torch.float32, device=device    )    period = min_period * (max_period / min_period) ** fraction    scaling_factor = 1.0 / period * 2 * math.pi    sin_input = scaling_factor[None, :] * time[:, None]    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)    return pos_emb# 例子 A 的参数与 tD = 8min_period = 0.125max_period = 1.0t = torch.tensor([0.00.250.50.751.0], dtype=torch.float32)# 用你的函数计算pos_code = create_sinusoidal_pos_embedding(    time=t, dimension=D, min_period=min_period, max_period=max_period, device="cpu")# 用“闭式公式矩阵”(上面手算得到的矩阵)作为期望结果pos_formula_closed = torch.tensor([    [000,  0,  1,  1,  1,  1],    [000,  1,  1,  1-1,  0],    [000,  0,  1,  1,  1-1],    [000-1,  1,  1-1,  0],    [000,  0,  1,  1,  1,  1],], dtype=torch.float32)# 校验(允许极小的浮点误差)ok = torch.allclose(pos_code, pos_formula_closed, rtol=1e-3, atol=1e-3)max_abs_diff = (pos_code - pos_formula_closed).abs().max().item()print("allclose:", ok)print("max_abs_diff:", max_abs_diff)print("pos_code:\n", pos_code)

Step3:动作嵌入

action_emb = self.action_in_proj(noisy_actions)
  • action_in_proj是一个线性投影层 nn.Linear(self.config.max_action_dim, self.config.proj_width)

proj_width是啥?

  • 从论文附录B可以看到,action expert使用较小的嵌入维度(嵌入向量的特征维度,即每个嵌入向量的长度)以加速推理:

  • proj_width=1024,这比主VLM backbone(2048)小,用于平衡性能和速度

    主VLM backbone使用预训练的PaliGemma模型,其嵌入维度已经固定

    self.paligemma_with_expert = PaliGemmaWithExpertModel(
        paligemma_with_export_config
    )

Step4:时间嵌入扩展

time_emb = einops.repeat(time_emb, "b d -> b n d", n=action_emb.shape[1])
  • time_emb原始维度:(batch_size, proj_width)
  • n=action_emb.shape[1]:即n_action_steps=50
  • einops.repeat:将时间嵌入在第1维重复n次

详细理解:

50个动作共享一个去噪时间步

在flow matching中,时间步  是全局的,但每个动作token都需要知道当前的"去噪进度"

怎么理解?

time_step = 0.3  # 全局去噪进度30%

time_emb = repeat(time_step_embedding, 50)  # 复制给50个token,t=0.3对所有50个动作同时生效

Step5:动作-时间融合

action_time_emb = torch.cat([action_emb, time_emb], dim=-1)
  • 在最后一维(特征维)拼接动作嵌入和时间嵌入
  • 拼接后维度:(batch_size, 50, 2*proj_width)

这种拼接方式让MLP能够学习动作特征和时间信息之间的复杂交互关系,类似于论文中提到的"条件化"过程。

Step6:MLP处理

action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb)  # swish activation
action_time_emb = self.action_time_mlp_out(action_time_emb)

这几行都为MLP结构

# __init__函数
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)

维度变化:

  • 输入:(B, 50, 2048)
  • action_time_mlp_in: (B, 50, 1024)
  • SiLU激活
  • action_time_mlp_out: (B, 50, 1024)

SiLU激活函数选择:

SiLU(Sigmoid Linear Unit)被证明在transformer架构中比ReLU更有效,特别是在需要平滑梯度的场景中。

Step7:最终嵌入组装

embs = torch.cat([state_emb[:, None], action_time_emb], dim=1)

详细解析:

  • state_emb[:, None]:为状态嵌入增加序列维度 (B, s) → (B, 1, s)
  • 在序列维度(第1维)拼接状态和动作-时间嵌入

最终结构:

  • state_emb维度:(B, 1, proj_width) 1个状态token
  • action_time_emb维度:(B, 50, proj_width)    50个动作token
  • 最终embs维度: (B, 51, proj_width)   总共51个token

Step8:注意力掩码设置

att_masks = torch.zeros((bsize, action_time_dim + 1), device=device, dtype=torch.bool)
att_masks[:, :2] = True
  • att_masks[:, :2] = True:前两个位置设为True

为什么构造att_masks[:, :2] = True

  • 序列关系

    序列结构:[状态token, 动作token0, 动作token1, ..., 动作token49]
    位置索引:[    0    ,     1     ,     2     , ...,     50    ]
  • make_att_2d_masks函数的逻辑(后面详细解释)

    # mask_ar: 1表示"前面的token不能依赖这个token",即因果边界
    cumsum = torch.cumsum(att_masks, dim=1)
    # cumsum = [1, 2, 2, 2, ..., 2]
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
  • 注意力矩阵的生成

             位置: 0  1  2  3 ... 50
    位置0(状态):  ✓  ✗  ✗  ✗ ... ✗   # 只能attend自己
    位置1(动作):  ✓  ✓  ✓  ✓  ... ✓   # 能attend所有token  
    位置2(动作):  ✓  ✓  ✓  ✓  ... ✓   # 能attend所有token
    ...
    位置50(动作): ✓  ✓  ✓  ✓  ... ✓   # 能attend所有token
  • 位置0(状态):边界阻止它attend后面的token

  • 位置1(第1个动作):边界确保动作块开始,但动作内部没有更多边界

1D注意力向量转2D注意力矩阵 make_att_2d_masks

defmake_att_2d_masks(pad_masks, att_masks):
    """Copied from big_vision.

    Tokens can attend to valid inputs tokens which have a cumulative mask_ar    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to    setup several types of attention, for example:      [[1 1 1 1 1 1]]: pure causal attention.      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between          themselves and the last 3 tokens have a causal attention. The first          entry could also be a 1 without changing behaviour.      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a          block can attend all previous blocks and all tokens on the same block.    Args:      input_mask: bool[B, N] true if its part of the input, false if padding.      mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on        it and 0 where it shares the same attention mask as the previous token.
    """

    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    att_2d_masks = att_2d_masks & pad_2d_masks
    return att_2d_masks

Step 1: 输入验证

if att_masks.ndim != 2:
    raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
    raise ValueError(pad_masks.ndim)

确保输入都是2维tensor:(batch_size, seq_len)

Step 2: 累积和计算

cumsum = torch.cumsum(att_masks, dim=1)

关键理解:cumsum是注意力"组ID"

# π0中的实际情况
att_masks = [TrueTrueFalseFalse, ..., False]  # 长度51
#            状态  动作0  动作1   动作2      动作49

cumsum = torch.cumsum(att_masks, dim=1)
# cumsum = [1, 2, 2, 2, ..., 2]
#          状态 动作组的所有token都是组ID=2

cumsum的含义:

  • 每次遇到True(因果边界),组ID递增
  • 同一组内的token有相同的cumsum值
  • 不同组之间有因果关系

Step 3: 生成2D注意力掩码

att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]

维度变换规则:

cumsum.shape = (B, N)
cumsum[:, None, :].shape = (B, 1, N)  # 扩展为行向量
cumsum[:, :, None].shape = (B, N, 1)  # 扩展为列向量

# 广播比较得到
att_2d_masks.shape = (B, N, N)

注意力规则:att_2d_masks[i, j] = (cumsum[j] <= cumsum[i])

token_i可以attend token_j,当且仅当jcumsum ≤ icumsum

举个例子:

  • π0的实际数据:

    # 序列结构:[状态, 动作0, 动作1, 动作2, 动作3]
    att_masks = [11000]
    cumsum =    [12222]
  • 2D掩码矩阵计算:

              j=0  j=1  j=2  j=3  j=4  (被attend的token)
    i=0(状态)  1121212121  →  T F F F F
    i=1(动作)  1222222222  →  T T T T T  
    i=2(动作)  1222222222  →  T T T T T
    i=3(动作)  1222222222  →  T T T T T
    i=4(动作)  1222222222  →  T T T T T
  • 注意力矩阵解读:

    • 状态token():只能attend自己,形成孤立的注意力模式
    • 动作token():可以attend所有token,实现双向注意力

Step 4: Padding掩码应用

pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]

举个例子:

假设我们有一个batch,序列长度为5:

pad_masks = torch.tensor([
    [TrueTrueTrueFalseFalse]  # 前3个是真实token,后2个是padding
])

生成2D padding掩码的代码如下

pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]

其中

pad_masks[:, None, :] = [[[TrueTrueTrueFalseFalse]]]
# shape: (1, 1, 5)

pad_masks[:, :, None] = [[[True],
                          [True], 
                          [True],
                          [False],
                          [False]]]
# shape: (1, 5, 1)

根据广播乘法

# 广播后的乘法运算(逻辑与)
pad_2d_masks = [[[TrueTrueTrueFalseFalse]]] * [[[True],
                                                        [True], 
                                                        [True],
                                                        [False],
                                                        [False]]]

# 结果:pad_2d_masks.shape = (1, 5, 5)
pad_2d_masks = [
    [[ True,  True,  TrueFalseFalse],   # 位置0 * [T,T,T,F,F]
     [ True,  True,  TrueFalseFalse],   # 位置1 * [T,T,T,F,F] 
     [ True,  True,  TrueFalseFalse],   # 位置2 * [T,T,T,F,F]
     [FalseFalseFalseFalseFalse],   # 位置3 * [T,T,T,F,F]
     [FalseFalseFalseFalseFalse]]   # 位置4 * [T,T,T,F,F]
]

pad_2d_masks[i,j]的含义:

  • 只有当位置和位置都是有效token时,才为True
  • 即:两个padding位置之间不能有注意力连接

得到综合注意力矩阵

att_2d_masks = att_2d_masks & pad_2d_masks

既要满足因果关系,又要都是有效token

Flowmatching

从随机噪声生成50步机器人动作序列,使用flow matching的你   过程。


defsample_actions(
    self, images, img_masks, lang_tokens, lang_masks, state, noise=None
)
 -> Tensor:

    """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""    bsize = state.shape[0]    device = state.device    dtype = state.dtype    if noise isNone:        actions_shape = (            bsize,            self.config.n_action_steps,            self.config.max_action_dim,        )        noise = torch.randn(actions_shape, device=device, dtype=dtype)    prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(        images, img_masks, lang_tokens, lang_masks    )    prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)    prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1    # Compute image and language key value cache    _, past_key_values = self.paligemma_with_expert.forward(        attention_mask=prefix_att_2d_masks,        position_ids=prefix_position_ids,        past_key_values=None,        inputs_embeds=[prefix_embs, None],        use_cache=self.config.use_cache,        fill_kv_cache=True,    )    dt = torch.tensor(-1.0 / self.config.num_steps, dtype=dtype, device=device)    x_t = noise    time = torch.tensor(1.0, dtype=dtype, device=device)    while time >= -dt / 2:        expanded_time = time.expand(bsize)        v_t = self.predict_velocity(            state, prefix_pad_masks, past_key_values, x_t, expanded_time        )        # Euler step        x_t += dt * v_t        time += dt    return x_tdefpredict_velocity(self, state, prefix_pad_masks, past_key_values, x_t, timestep):    """predict velocity at time t using the suffix model."""    suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(        state, x_t, timestep    )    suffix_len = suffix_pad_masks.shape[1]    batch_size = prefix_pad_masks.shape[0]    prefix_len = prefix_pad_masks.shape[1]    prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(        batch_size, suffix_len, prefix_len    )    suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)    full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)    prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]    position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1    outputs_embeds, _ = self.paligemma_with_expert.forward(        attention_mask=full_att_2d_masks,        position_ids=position_ids,        past_key_values=past_key_values,        inputs_embeds=[None, suffix_embs],        use_cache=self.config.use_cache,        fill_kv_cache=False,    )    suffix_out = outputs_embeds[1]    suffix_out = suffix_out[:, -self.config.n_action_steps :]    v_t = self.action_out_proj(suffix_out)    return v_t

KV缓存在双向注意力中的数学原理

分块注意力机制的数学公式

我们回顾注意力机制(Self-Attention)的数学公式:

其中:

  • :查询矩阵
  • :键矩阵
  • :值矩阵

中,序列结构可以被表示为:

其中:

  • :图像+语言tokens
  • :状态+动作tokens

根据序列结构,对注意力进行分块表示

此时,注意力分数矩阵可以写为:

考虑注意力掩码矩阵

符号含义:

  • :prefix序列长度 = 图像token数量 + 语言token数量
  • :suffix序列长度 = 1个状态token + 50个动作token = 51

矩阵维度的具体解释:

**** 

  • 全1矩阵,大小为 
  • 表示prefix内部的双向注意力
  • 所有图像token和语言token之间可以相互attend

**** 

  • 全0矩阵,大小为 
  • 表示prefix不能attend suffix
  • 图像和语言token不能看到状态和动作token

**** 

  • 全1矩阵,大小为 
  • 表示suffix可以attend prefix
  • 状态和动作token可以看到所有图像和语言token

**** 

  • suffix内部的注意力掩码,大小为 
  • make_att_2d_masks函数根据因果规则生成
  • 控制状态token和动作token之间的注意力模式

π0典型的序列长度:

  • :取决于图像patch数量和语言token数量
  • :1个状态token + 50个动作token

掩码后的注意力分数为

Hadamard乘积(也叫逐元素乘积element-wise multiplication)的符号。

数学定义:

对于两个相同维度的矩阵  和 

即对应位置的元素相乘。

在注意力掩码中的应用:

在之前的公式中:

含义:

  • :原始注意力分数矩阵
  • :注意力掩码矩阵(True/False 或 1/0)
  • :应用掩码后的注意力分数

具体操作:

  • 掩码为1的位置:保留原始注意力分数
  • 掩码为0的位置:注意力分数变为0(阻止注意力连接)

示例:

这种操作实现了选择性的注意力屏蔽。

KV缓存的数学原理

prefix块的输出:

prefix输出的独立性:

即prefix的输出仅依赖于prefix自身,与suffix无关。


suffix块的输出:

suffix输出的依赖性:

即suffix的输出依赖于prefix和suffix。

缓存机制的数学实现

第一次前向传播

计算并缓存prefix的KV:

计算prefix的注意力输出:

后续前向传播(

重用缓存的prefix KV:

计算当前步的suffix KV:

拼接完整的KV矩阵:

计算suffix的注意力输出:

双向注意力兼容性的数学证明

依赖关系的单向性

prefix的计算不变性:

这是因为:

suffix的动态依赖:

计算复杂度分析

无缓存的复杂度:

有缓存的复杂度:

  • 首次:
  • 后续每步:

节省的计算量:

数学等价性验证

完整计算:

缓存计算:

等价性:

因为对所有成立。

Flowmatching的原理

π0中的时间参数定义

π0的插值公式(与标准相反):

其中:

  • :插值参数
  • 真实动作序列
  • 纯噪声
  • :30%噪声 + 70%真实动作

Flow Matching的速度场

速度场定义:

逆向积分的ODE:

推理过程:从噪声到真实数据

Step1:初始化噪声

if noise isNone:
    actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
    noise = torch.randn(actions_shape, device=device, dtype=dtype)

生成原始噪音数据

矩阵维度:(B, 50, 18)

Step2:Prefix处理和KV缓存优化

prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
    images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

# Compute image and language key value cache
_, past_key_values = self.paligemma_with_expert.forward(
    attention_mask=prefix_att_2d_masks,
    position_ids=prefix_position_ids,
    past_key_values=None,
    inputs_embeds=[prefix_embs, None],
    use_cache=self.config.use_cache,
    fill_kv_cache=True,
)

这一部分的代码在之前都进行详细拆解过,在此略过

Step3:Flow Macthing积分循环

dt = torch.tensor(-1.0 / self.config.num_steps, dtype=dtype, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=dtype, device=device)

while time >= -dt / 2:
    expanded_time = time.expand(bsize)
    v_t = self.predict_velocity(state, prefix_pad_masks, past_key_values, x_t, expanded_time)
    
    # Euler step
    x_t += dt * v_t
    time += dt

积分参数:

欧拉积分法:

如何求解这个速度,在下一个模块详细解释

积分路径:

对应的数据变化:

终止条件:

这确保最后一步能完整执行到 

predict_velocity函数详细解析

给定当前状态  和时间 ,预测速度场 

Step1:生成suffix嵌入

suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
    state, x_t, timestep
)
  • suffix_embs :1个状态token + 50个动作token的嵌入

    •  = proj_width = 1024 (action_expert的嵌入维度)
  • suffix_pad_masks : padding掩码,由于不存在padding,故全为True

  • suffix_att_masks :因果掩码,前2个为True,其余为False

Step2:提取维度信息

suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
  • suffix_len = 51 - suffix序列长度
  • batch_size = B - 批次大小
  • prefix_len = L_p - prefix序列长度(图像+语言token数量)

Step3:构建prefix到suffix的注意力掩码

prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
    batch_size, suffix_len, prefix_len
)

数学含义:

即:suffix的每个token(行)都可以attend prefix的所有有效token(列)。

可视化(假设prefix有5个token,suffix有3个token):

prefix_pad_masks = [True, True, True, False, False]

prefix_pad_2d_masks = [
    [True, True, True, False, False],  # suffix token 0
    [True, True, True, False, False],  # suffix token 1
    [True, True, True, False, False],  # suffix token 2
]

为什么这样设计?

  • suffix的每个token需要知道prefix中哪些是有效token
  • 通过复制prefix的padding掩码,确保suffix只attend有效的prefix token
  • 这是一个单向的跨注意力掩码

Step4:生成suffix内部的注意力掩码

suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)

回顾输入:

suffix_att_masks = [TrueTrueFalseFalse, ..., False]  # 长度51
#                   状态  动作0  动作1-49

make_att_2d_masks的计算:

生成的掩码矩阵:

           j=0  j=1  j=2  ... j=50
i=0(状态)    T    F    F   ...  F     # 状态只能attend自己
i=1(动作0)   T    T    T   ...  T     # 动作可以attend所有
i=2(动作1)   T    T    T   ...  T
...
i=50(动作49) T    T    T   ...  T

输出维度:

Step5:拼接完整的注意力掩码

full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)

拼接操作:

维度变化:

prefix_pad_2d_masks.shape = (B, 51, L_p)
suffix_att_2d_masks.shape = (B, 5151)

# 在第2维(列维度)拼接
full_att_2d_masks.shape = (B, 51, L_p + 51)

完整掩码矩阵的结构:

        [prefix tokens: L_p] [suffix tokens: 51]
状态    [     全部可见      ] [   只能看自己      ]
动作0   [     全部可见      ] [   全部可见        ]
动作1   [     全部可见      ] [   全部可见        ]
...
动作49  [     全部可见      ] [   全部可见        ]

数学表示:

Step6:计算位置编码

prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
  • prefix_offsets只统计有效token,自动跳过padding
  • cumsum确保suffix内部的位置是连续的

示例:

# Prefix的padding掩码 (B=1, L_p=10)
prefix_pad_masks = torch.tensor([    [TrueTrueTrueTrueTrueTrueTrueFalseFalseFalse]])#    有效tokens: 7个                              padding: 3个# Suffix的padding掩码 (B=1, L_s=6)suffix_pad_masks = torch.tensor([    [TrueTrueTrueTrueTrueTrue]])#    状态 + 5个动作,全部有效prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]# prefix_offsets = torch.tensor([[7]])  # shape: (1, 1)torch.cumsum(suffix_pad_masks, dim=1)# cumsum = torch.tensor([[1, 2, 3, 4, 5, 6]])  # shape: (1, 6)# Prefix的有效位置:0-6# position_ids = torch.tensor([[7, 8, 9, 10, 11, 12]])  # shape: (1, 6)

Step7:Transformer前向传播

outputs_embeds, _ = self.paligemma_with_expert.forward(
    attention_mask=full_att_2d_masks,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=[None, suffix_embs],
    use_cache=self.config.use_cache,
    fill_kv_cache=False,
)
  • attention_mask=full_att_2d_masks:

    • 形状:
    • 控制suffix_tokens的注意力模式
  • position_ids=position_ids:

    • 形状:
    • suffix tokens的位置编码
  • past_key_values=past_key_values:

    • 缓存的prefix KV矩阵
    • 包含:
  • inputs_embeds=[None, suffix_embs]:

    • None:表示不计算prefix的新嵌入
    • suffix_embs:当前步的suffix嵌入,形状 
  • use_cache=True: 保持KV缓存

  • fill_kv_cache=False: 不更新prefix的KV缓存

上述代码使用KV缓存优化了inputs_embeds = [prefix_embs, suffix_embs]步骤

  • 不需要重新计算prefix的嵌入和KV

  • 根据完整的注意力计算公式

    其中:

    直接从缓存中调用数据

    • :从past_key_values获取
    • :从suffix_embs计算

Step8:提取suffix输出

suffix_out = outputs_embeds[1]
  • outputs_embeds是一个列表:[prefix_output, suffix_output]
  • 我们只需要suffix的输出(以token形式存在)

Step9:提取动作tokens的输出

suffix_out = suffix_out[:, -self.config.n_action_steps :]

切片操作:

  • -self.config.n_action_steps = -50
  • 提取最后50个token(动作tokens)
  • 丢弃第一个token(状态token)

维度变化:

为什么丢弃状态token?

  • 状态token只用于提供上下文信息
  • 速度场只需要从动作tokens预测
  • 对应50步动作序列

Step10:投影到动作空间

v_t = self.action_out_proj(suffix_out)

维度变化:

输出含义:

Step11:返回速度场

return v_t

最终输出:

物理含义:

  • 在当前状态  和时间  下
  • 动作序列应该朝哪个方向"移动"
  • 这个方向由神经网络学习得到

声明:内容取材于网络,仅代表作者观点,如有内容违规问题,请联系处理。 
more
谷歌母公司 Alphabet 营收超预期,云业务成增长关键引擎 | 区势·BigTech
科技圈变电商圈?豆包们开始带货了
联合国全球契约组织代表团访问清华大学人工智能国际治理研究院
战略合作!安波福与Robust.AI携手开发人工智能协作机器人
中国移动总经理何飚:算网强基 数智驱动 聚力推动人工智能赋能新型工业化
股价飙升20%!刚刚,高通发布人工智能芯片!
一周全球公司十大要闻 | 苹果10月份在华手机销量同比激增37%;黄仁勋反驳对人工智能泡沫的担忧
爆涨20%!刚刚,高通发布人工智能芯片!
成都人工智能产业规模迈上千亿元台阶
这家由前腾讯人工智能专家创办的公司完成种子轮融资
Copyright © 2025 成都区角科技有限公司
蜀ICP备2025143415号-1
  
川公网安备51015602001305号