
源码来自@董子斌 https://github.com/ZibinDong/openpi_pytorch
从demo理解输入和输出
# 输入
observation = {
"image": {
"base_0_rgb": torch.randint(
0, 256, (1, 3, 224, 224), dtype=torch.uint8, device=device
),
# "left_wrist_0_rgb": ..., Suppose we don't have this view
# "right_wrist_0_rgb": ..., Suppose we don't have this view
},
"state": torch.randn(1, 8, device=device) * 0.2,
"prompt": ["do something"],
}
# 运行模型和action输出
action = policy.select_action(observation)[0, :, :7]
算法输入
观测信息observation包含三部分:图像、机器人状态向量、语言指令。
图像信息(observation["image"])
-
支持三路相机输入
可以只输入部分,缺失的会自动用掩码补齐
-
base_0_rgb(主相机) -
left_wrist_0_rgb(左侧腕部相机) -
right_wrist_0_rgb(右侧腕部相机) -
每张输入图片的格式为
uint8、范围[0,255],形状(*b, 3, H, W)
机器人状态 observation["state"]
-
float32,形状(*b, s),表示机器人的本体状态(例如关节角),最多32维度
语言指令 observation["prompt"]
算法输出
输出格式和维度
-
针对返回值
actions,其形状为(*b, config.n_action_steps, config.max_action_dim) -
在默认情况下,
-
n_action_steps = 50(输出一个“动作块/chunk”,又叫“动作地平线 ”) -
max_action_dim = 32(短动作向量会右侧 pad 到 32)。
针对代码
action = policy.select_action(observation)[0, :, :7]
表示取出第 1 个 batch,全部 50 个时间步,并截取前 7 个动作维度
动作维度和机器人怎么对应?
PI0 是“跨机型/跨体态”的统一策略,训练时把不同机器人统一到同一上限维度,并对不足的维度做零填充(state/action 都一样)。例如论文里总结了典型平台的“配置/动作维度”:
这些都统一 pad 到数据集最大维度(论文示例最大 18,代码实现允许到 32 作为上限)。使用时你按自己的机器人有效 DoF 截取前action_dim维即可(比如 UR5e 取前 7 维)。
为什么代码里面是:7?
因为示例里假设是 UR5e(6 关节 + 夹爪 = 7),于是 :7 正好取有效部分;后面的列是 padding。
-
单臂:UR5e(7 维:6 关节 + 1 夹爪);Franka(8 维:7 关节 + 1 夹爪) -
双臂:常见是 14 维(两只 6DoF 手臂 + 两个夹爪) -
移动操作臂:在 14 维基础上再加底盘控制维度(非全向底盘 +2,自由底盘 +3),所以 16 或 17 维
输出action的物理含义
-
关节(Arm Joints):
在 PI0Config 里有配置信息:
默认输出的关节部分就是 绝对值目标(absolute joint positions),如果在 ALOHA 适配场景中把 use_delta_joint_actions_aloha=True,则关节动作会变成相对于当前状态的增量 Δq。夹爪保持绝对值。
-
adapt_to_pi_aloha: bool = False -
use_delta_joint_actions_aloha: bool = False
-
夹爪(Gripper):
源码里明确注释:
“Gripper dimensions will remain in absolute values.”
所以无论 use_delta_joint_actions_aloha 是否开启,夹爪维度都是 绝对值。
-
移动底盘(Mobile Base) :
移动底盘采用 twist 指令,即:
twist 指令的坐标系为车身坐标系(body frame)
“These actions are in the robot’s local body frame.”
-
如果机器人有非全向底盘,则动作维度中包含 → 平移速度、角速度
-
如果是全向底盘,则包含 ;
从action输出到实际控制
-
控制频率:
论文明确:PI0 能以最高 50 Hz 控制执行(例如执行高灵巧操作)。也有平台以 20 Hz 运行(UR5e/Franka)
-
推理 / 执行的节奏(关键点)
开环地顺序执行动作块(作者尝试过块间“集成/平滑”策略,会变差,遂放弃)!
论文在附录 D 给了实际节奏:
-
20 Hz 机器人:每 0.8 s 重新前向一次(也就是先执行 16 步,再重算下一个动作块) -
50 Hz 机器人:每 0.5 s 前向一次(执行 25 步再重算)。
-
一次前向会生成一个50步的“动作块”( n_action_steps = 50) -
按给定的控制频率逐步执行这些动作(例如 50 Hz 就每 20 ms 执行下一步)
多模态数据的准备
policy.select_action(observation)的具体实现如下:
classPI0Policy(PreTrainedPolicy):
@torch.no_grad
defselect_action(
self, observation: dict[str, Tensor], noise: Tensor | None = None
):
"""
Observation: {
"image": {
"base_0_rgb": (*b, c, h, w), # uint8 [0, 255]
...
},
"state": float32 [*b, s],
"prompt": List[str],
"lang_tokens": float32 [*b, l],
"lang_masks": float32 [*b, l],
}
either provide `prompt` or (`lang_tokens`, `lang_masks`).
"""
self.eval()
images, img_masks = self.prepare_images(observation)
state = self.prepare_state(observation)
lang_tokens, lang_masks = self.prepare_language(observation)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
return actions
针对图像的预处理prepare_images
defprepare_images(self, observation: dict[str, Tensor]):
"""Normalize, resize, and pad images and stack them into a tensor.
Args:
observation (dict[str, Tensor])
Returns:
images (torch.Tensor): (*b, n, c, h, w) images in range [-1.0, 1.0]
img_masks (torch.Tensor): (*b, n) masks for images, True if image is present, False if missing
""" dtype = observation["state"].dtype # 获取状态的 dtype,通常是 float32 bsize = observation["state"].shape[0] # 获取批大小(batch size) images, img_masks = [], [] # 用于存储图像和掩码 # 确定哪些图像存在 present_img_keys = [key for key in IMAGE_KEYS if key in observation["image"]] missing_img_keys = [key for key in IMAGE_KEYS if key notin present_img_keys] # 处理存在的图像 for key in present_img_keys: img = observation["image"][key] # 获取图像 img = img.to(dtype) / 127.5 - 1.0# 归一化到 [-1, 1] 的范围 img = resize_with_pad( # 调整图像大小并填充 img, *self.config.resize_imgs_with_padding, pad_value=-1.0 ) images.append(img) # 存储处理过的图像 img_masks.append(torch.ones((bsize,), dtype=torch.bool, device=img.device)) # 对应掩码为 True # 处理缺失的图像(填充) for key in missing_img_keys: img = torch.full_like(img, fill_value=-1.0) # 填充 -1.0 images.append(img) # 存储填充图像 img_masks.append(torch.zeros((bsize,), dtype=torch.bool, device=img.device)) # 对应掩码为 False # 堆叠图像和掩码 images = torch.stack(images, dim=1) # (*b, n, c, h, w) img_masks = torch.stack(img_masks, dim=1) # (*b, n) return images, img_masks
Step 1:获取图像键、batch 信息与容器初始化
dtype = observation["state"].dtype # 一般是 torch.float32
bsize = observation["state"].shape[0] # batch_size (b)
images, img_masks = [], [] # 收集每个相机的图像与掩码
# 支持的相机顺序(决定输出的相机维度顺序)
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
present_img_keys = [key for key in IMAGE_KEYS if key in observation["image"]]
missing_img_keys = [key for key in IMAGE_KEYS if key notin present_img_keys]
-
dtype:用 state 的 dtype,把图像转成相同精度,方便后续拼接/前向一致(通常float32)。 -
bsize:batch 大小b,后面所有图像、掩码都会以b为第一维。 -
IMAGE_KEYS:支持的相机视角("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb") -
present_img_keys:实际存在的相机键列表。 -
missing_img_keys:缺失的相机键列表(后面会用占位图和False掩码补齐)。
Step 2:图像预处理
# 处理存在的图像
for key in present_img_keys: img = observation["image"][key] # (b, 3, H, W),uint8,范围 [0, 255] img = img.to(dtype) / 127.5 - 1.0 # (b, 3, H, W),float,范围 [-1, 1] img = resize_with_pad( # 保持比例 resize 到 224,并用 -1.0 pad 到 (224, 224) img, *self.config.resize_imgs_with_padding, pad_value=-1.0 ) # (b, 3, 224, 224) images.append(img) # 先以“列表”形式收集 (b, 3, 224, 224) img_masks.append( torch.ones((bsize,), dtype=torch.bool, device=img.device) ) # (b,) = 全 True(这一路相机存在)# 处理缺失的图像(填充)for key in missing_img_keys: img = torch.full_like(img, fill_value=-1.0) # 填充 -1.0 images.append(img) # 存储填充图像 img_masks.append(torch.zeros((bsize,), dtype=torch.bool, device=img.device)) # 对应掩码为 False
-
归一化:把图像像素值从 [0, 255]转换为[-1, 1]。 -
resize 和 padding:将图像调整为指定的尺寸(在配置中为 (224, 224))并进行填充。填充是为了统一图像大小,保持输入张量的一致性。如果某些图像没有数据,预设会用值为-1.0的填充值填充。 -
暂存 images和img_masks(到下一个Step进行讲解)
Step3:图像和掩码堆叠
# 堆叠图像和掩码
images = torch.stack(images, dim=1) # (*b, n, c, h, w)
img_masks = torch.stack(img_masks, dim=1) # (*b, n)
图像堆叠(images):
将所有图像堆叠成一个五维张量,形状是 (b, n, 3, 224, 224),其中:
-
b是批大小(batch size), -
n是有效的图像数量(可能是 1, 2 或 3), -
3是通道数(RGB图像), -
224和224是图像的高度和宽度。
掩码堆叠(img_masks)
对应的掩码张量 img_masks 也会堆叠,形状是 (b, n),表示每张图像是否有效
-
如果图像存在,掩码的值是 True,表示图像有效。 -
如果图像缺失,掩码的值是 False,表示图像缺失或无效。
这里给出当有效图像只有"base_0_rgb", "left_wrist_0_rgb"时的例子
-
present_img_keys:["base_0_rgb", "left_wrist_0_rgb"] -
missing_img_keys:["right_wrist_0_rgb"] -
下图给出 image_masks的堆叠情况 -
针对 images的堆叠,维度和下图类似,每个1或0的位置一张图片矩阵或由-1填充得到的矩阵

针对机器人状态的预处理prepare_state
defprepare_state(self, observation: dict[str, Tensor]):
"""Pad the state to the maximum state dimension.
Args:
observation (dict[str, Tensor])
Returns:
state (torch.Tensor): (*b, max_state_dim) padded state tensor
"""
state = observation["state"]
state = F.pad(state, (0, self.config.max_state_dim - state.shape[1]))
return state
把 observation["state"]从原始维度(b, s)右侧补零到 (b, max_state_dim),使每个样本的状态向量长度一致,便于后续喂给固定输入宽度的线性层。
torch.nn.functional.pad(input, pad, mode='constant', value=0)
pad 参数是一个元组,从 最后一维开始,成对指定 (left, right) 的填充数量。
-
pad是一个元组,从最后一维开始,成对指定(pad_left, pad_right)。 -
如果张量是多维的,就继续往前一维一维写。例如 2D 就需要两个维度的 pad(即 4 个数),3D 需要 6 个数,以此类推。 -
默认填充值是 0,可以通过value=...修改。
Example:
x = torch.tensor([1,2,3]) # pad=(l, r), 在左边补 l 个数, 在右边补 r 个数
F.pad(x, (2,1)) # → [0,0,1,2,3,0]
x = torch.arange(6).reshape(2,3) # shape (2,3)
# [[0,1,2],
# [3,4,5]]
y = F.pad(x, (1,2, 2,1)) # pad=(w_left, w_right, h_left, h_right)
# w_left=1, w_right=2, h_left=2, h_right=1
# 输出形状 (2+2+1, 3+1+2) = (5,6)
x = torch.zeros(1,2,3) # shape (C=1,H=2,W=3)
# pad=(w_left, w_right, h_left, h_right, c_left, c_right)
y = F.pad(x, (1,1, 2,0, 0,2))
# W 左右各补1 → W=5
# H 上补2下补0 → H=4
# C 前补0后补2 → C=3
# y.shape == (3,4,5)
针对任务指令的预处理prepare_language
defprepare_language(self, observation: dict[str, Tensor]):
"""
返回:
lang_tokens: (*b, l) # int 型 token ids
lang_masks : (*b, l) # bool 掩码, True=该位置有有效 token, False=padding
""" lang_tokens = observation.get("lang_tokens", None) lang_masks = observation.get("lang_masks", None) prompt = observation.get("prompt", None) # 必须二选一: prompt 或 (lang_tokens, lang_masks) if prompt isNoneand (lang_tokens isNoneor lang_masks isNone): raise ValueError(...) device = observation["state"].device if prompt isnotNoneand (lang_tokens isNoneor lang_masks isNone): # 1) 规范化文本到 PaliGemma 期望格式 prompt = [p if p.startswith("" ) elsef" {p}"for p in prompt] prompt = [p if p.endswith("\n") elsef"{p}\n"for p in prompt] # 2) 调 tokenizer -> 得到 (b, L) 的 ids 与 mask tokenized_prompt = self.language_tokenizer.__call__( prompt, padding="max_length", padding_side="right", max_length=self.config.tokenizer_max_length, return_tensors="pt", ) lang_tokens = tokenized_prompt["input_ids"].to(device=device) lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) else: # 直接使用用户提供的 token/mask lang_tokens = observation["lang_tokens"].to(device=device) lang_masks = observation["lang_masks"].to(device=device, dtype=torch.bool) return lang_tokens, lang_masks
输入模式与选择逻辑
-
可以传 原始文本
prompt: List[str],也可以直接传 已分好词的lang_tokens和lang_masks(两者必须成对)。 -
批次大小
b来自observation["state"].shape[0]。无论是prompt还是lang_tokens/lang_masks,第一维都要等于b。
使用 prompt 的处理流程(最常见)
Step1:规范化到 PaliGemma 前缀格式
prompt = [p if p.startswith("") else f"{p}" for p in prompt]
prompt = [p if p.endswith("\n") else f"{p}\n" for p in prompt]
-
在开头加 ,在结尾加换行\n(PaliGemma 以\n作为分隔的约定)。 -
如果原文本已经有 或已以\n结尾,就不会重复添加。
Step2:tokenizer
tokenized_prompt = self.language_tokenizer.__call__(
prompt,
padding="max_length", # 不足 L 的右侧 pad 到 L
padding_side="right",
max_length=self.config.tokenizer_max_length, # L,默认 48
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device) # 形状: (b, L), dtype=torch.long
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) # 形状: (b, L), True=非pad
矩阵维度:
-
lang_tokens:(b, L),整型 token id(通常是torch.long) -
lang_masks:(b, L),布尔掩码; -
True 表示该位置是有效 token(包括 与正文) -
False 表示 padding。
padding:
-
右侧补齐( padding_side="right") -
上面没有显式写 truncation=True。大多数 HF tokenizer 在max_length+padding="max_length"情况下会对超长序列进行截断(同时给出 warning)。
使用 lang_tokens/lang_masks 的处理流程
lang_tokens = observation["lang_tokens"].to(device=device) # (b, L)
lang_masks = observation["lang_masks"].to(device=device, dtype=torch.bool) # (b, L)
-
传入的tokens/masks 将被直接使用,不会再做
/\n之类的改写,也不会再 padding。 -
请务必保证:
-
形状第一维等于 batch 大小 b; -
第二维 L不应超过config.tokenizer_max_length(否则后续拼接位置可能不对齐); -
lang_tokens.dtype应为 整型(通常torch.long),因为后续会走 embedding; -
lang_masks.dtype为 bool,True=非 pad,False=pad。
信息嵌入
把“图像 + 语言提示”当作前缀(prefix) ,把“状态/动作/时间步”等控制相关的 token 当作后缀(suffix)
前缀信息嵌入 embed_prefix
总览
defembed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
Args:
images (torch.Tensor): float (*b, n, c, h, w) images in range [-1.0, 1.0]
img_masks (torch.Tensor): bool (*b, n) masks for images
lang_tokens (torch.Tensor): int (*b, l) language tokens
lang_masks (torch.Tensor): bool (*b, l) masks for language tokens
""" bsize = images.shape[0] device = images.device dtype = images.dtype # embed image images = einops.rearrange(images, "b n c h w -> (b n) c h w") img_emb = self.paligemma_with_expert.embed_image(images) num_patch = img_emb.shape[1] img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize) img_emb = img_emb.to(dtype=dtype) * (img_emb.shape[-1] ** 0.5) num_img_embs = img_emb.shape[1] img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch) # embed language lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) num_lang_embs = lang_emb.shape[1] lang_emb = lang_emb.to(dtype=dtype) * np.sqrt(lang_emb.shape[-1]) # assemble embeddings embs = torch.cat([img_emb, lang_emb], dim=1) pad_masks = torch.cat([img_masks, lang_masks], dim=1) # PaliGemma uses bidirectional attention for prefix tokens, # so we set 1D `att_masks` to zeros. # (see `make_att_2d_masks` to understand why zeros means bidirection) att_masks = torch.zeros( (bsize, num_img_embs + num_lang_embs), device=device, dtype=torch.bool ) return embs, pad_masks, att_masks
embed_prefix(...) 的职责就是把前缀两种模态(图像、文本)各自编码成同一维度的 token 序列,然后沿着序列维拼起来,并产出两类 mask:
-
pad_masks:哪一些位置是真实 token、哪一些是 padding(缺失图像或被 tokenizer pad 的 token)。 -
att_masks(1D) :前缀 token 的注意力“类型”标记(这里全部设为 0),供后续 make_att_2d_masks展开为真正的二维注意力矩阵时使用。对前缀,含义是“双向注意力”。
为什么前缀要双向?
PaliGemma 的前缀部分(图像 patch + 文本 token)用双向自注意力来做跨模态对齐/融合,这样所有前缀 token 能互相看见,利于把视觉上下文和语言提示融合后,再把“控制后缀”接进来驱动动作专家(Expert Gemma)。
Step1:图像编码(SigLIP 视觉编码器 → patch token)
images = einops.rearrange(images, "b n c h w -> (b n) c h w")
-
输入 images形状是(*b, n, c, h, w),n是视角数(top、left_wrist、right_wrist 等)。 -
这里把 b和n合起来,变成(b*n, c, h, w),这样可以直接喂给视觉编码器做逐图前向。
img_emb = self.paligemma_with_expert.embed_image(images)
调 SigLIP/ViT-style 的视觉编码:输出形状通常是 ((b*n), L, D)
-
L:每张图切成多少个 patch,再加上可能的特殊 token(具体是否有 CLS 以实现为准;从后面逻辑看,它把所有长度叫num_patch用于 mask 展开)。对于 224×224、patch 16,L ≈ 14×14=196(若有 CLS 就 197)。 -
D:视觉嵌入维度(与 PaliGemma 主干的隐藏维一致或经投影一致,代码里稍后只用shape[-1])
num_patch = img_emb.shape[1]
# 再把 (b*n) 还原成 b,并把 “n 张图的 L 个 patch”摊平成一段连续序列
img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize)
还原后 img_emb 形状是 (b, n*L, D):也就是把多个相机的所有 patch 串接成“一段前缀图像 token 序列”。
# 统一 dtype(比如到 fp16/bf16),并做 “√D 放缩”
img_emb = img_emb.to(dtype=dtype) * (img_emb.shape[-1] ** 0.5)
num_img_embs = img_emb.shape[1]
-
num_img_embs = n*L,就是图像前缀 token 的总长度。
❝为什么乘
√D?
Transformer 的经典做法(源自 Vaswani et al.)是把词嵌入乘以 √d_model,以便和位置编码的量级匹配、并在层归一化前维持稳定的方差尺度。此处把视觉 patch 嵌入也乘 √D,让“图像 token”和后面“文本 token”在数值尺度上对齐,避免某一模态数值太小/太大导致注意力分布偏置。
# 把每张图的存在性 mask,从 (b, n) 展开到 (b, n*L)
img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch)
-
先前
prepare_images里,如果某个视角缺失,会放一张全 -1 的占位图,同时img_masks里该视角是 False。 -
这里展开后,每张图的 每个 patch 都继承该视角的 True/False。
-
这个 mask 之后用于:
-
构造二维注意力 mask 时屏蔽掉“来自缺失视角的 patch token”; -
计算 position_ids(position_ids = cumsum(pad_masks)-1),使 padding 的位置得到 -1(无效)。
为什么要扩维度?
einops.repeat(img_masks, "b n -> b (n l)", l=num_patch) 的作用就是 把每个相机的存在性掩码扩展到该相机的所有 patch token 上,保证和图像嵌入序列长度一致,从而正确屏蔽缺失视角的所有 patch。
原始形状:
-
img_masks在进入embed_prefix前是(b, n) -
b= batch size -
n= 视角数(top camera, left wrist, right wrist …)
每个位置只表示“这一张图是否存在”。 -
images在进入embed_prefix前是(b, n, c, h, w) -
images经过视觉编码器,每张图不再是一个整体,而是被切成num_patch个 patch(token),每个token向量的长度是。所以图像嵌入img_emb的形状是(b, n * num_patch, d)。
扩展原因:
因为 img_emb 有 n * num_patch个 token,而 img_masks 只有 n 个标志位。如果我们保持 img_masks 不变((b, n)),就没法和 img_emb 对齐。
所以要把 img_masks广播/扩展,让它为每个 patch 也提供一个 mask 值:
img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch)
新形状是 (b, n * num_patch),正好能和 img_emb 的序列长度对齐。
举个例子:
假设:
-
batch size b=1 -
只有 2 个相机 n=2 -
每张图切成 num_patch=3
那么:
-
img_masks原始:[[1, 0]](第一个相机存在,第二个相机缺失) -
扩展后:
[[1, 1, 1, 0, 0, 0]] -
前 3 个 patch 来自第一个相机 → 有效 -
后 3 个 patch 来自第二个相机 → 全部 mask 掉
这样,后续 Transformer 在做注意力时,就不会“看见”缺失相机的 patch。
Step2:文本编码(Gemma 词嵌入层)
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
num_lang_embs = lang_emb.shape[1]
lang_emb = lang_emb.to(dtype=dtype) * np.sqrt(lang_emb.shape[-1])
-
embed_language_tokens一般是查词表 + 线性投影或直接查表,输出(b, l, d)。
img_emb 的最后一维度和lang_emb 的最后一维度是一样的。因为这两个模态(图像 patch token 和语言 token)最后要 拼接在一起 形成一条连续的前缀序列,然后送进同一个 PaliGemma Transformer。Transformer 的输入必须在最后一维(hidden size)对齐,否则没法一起做注意力和线性变换。
-
同样做了 乘 √D 的放缩,理由与图像相同:把语言 token 的数值尺度和视觉 token 对齐,利于后续把两段序列拼接、共享同一个 Transformer 主干(PaliGemma)。
拼接序列 + 组合 mask
embs = torch.cat([img_emb, lang_emb], dim=1) # (b, 图像个数*每个图像的token数 + 语言token数, D)
pad_masks = torch.cat([img_masks, lang_masks], dim=1) # (b, n*L + l)
-
现在前缀序列是:“所有图像 patch token 在前,语言 token 在后”。 -
这一步确保两个模态统一在一条时间轴/位置轴上,方便后面一口气喂进 PaliGemma 的 Transformer。
生成前缀的 1D 注意力类型 mask(全 0 = 双向)
# 前缀使用双向注意力:把 1D att_masks 全部设为 0
# (make_att_2d_masks 里会把 0 解读为“bidirectional 前缀”,生成全互看的注意力矩阵)
att_masks = torch.zeros(
(bsize, num_img_embs + num_lang_embs), device=device, dtype=torch.bool
)
return embs, pad_masks, att_masks
-
这里的 att_masks不是“padding mask”,也不是最终的二维注意力 mask,而是一条 “每个 token 的注意力类型标记” 。 -
在后续的代码中, make_att_2d_masks会生成2D注意力矩阵
尽管att_masks 是2D矩阵,但由于第一维是batch,那么,对于一个 batch,att_masks[i] 就是一行向量,长度等于序列长度 seq_len。即,从“单条序列”的角度看,它就是 1D 的向量。
后缀信息嵌入embed_suffix
defembed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.
Args:
state (torch.Tensor): float32 (*b, s) robot state
noisy_actions (torch.Tensor): float32 (*b, n, m) noisy actions
timestep (torch.Tensor): float32 (*b,) timestep in [0, 1] range
""" bsize = state.shape[0] device = state.device dtype = state.dtype # embed state state_emb = self.state_proj(state) # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] time_emb = create_sinusoidal_pos_embedding( timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device, ) time_emb = time_emb.type(dtype=dtype) # Fuse timestep + action information using an MLP action_emb = self.action_in_proj(noisy_actions) time_emb = einops.repeat(time_emb, "b d -> b n d", n=action_emb.shape[1]) action_time_emb = torch.cat([action_emb, time_emb], dim=-1) action_time_emb = self.action_time_mlp_in(action_time_emb) action_time_emb = F.silu(action_time_emb) # swish == silu action_time_emb = self.action_time_mlp_out(action_time_emb) action_time_dim = action_time_emb.shape[1] # Add to input tokens embs = torch.cat([state_emb[:, None], action_time_emb], dim=1) pad_masks = torch.ones( (bsize, action_time_dim + 1), device=device, dtype=torch.bool ) # Set attention masks for suffix tokens so that prefix tokens cannot attend to suffix tokens. # And state token cannot attend action tokens. # Action tokens use a bidirectional attention. att_masks = torch.zeros( (bsize, action_time_dim + 1), device=device, dtype=torch.bool ) att_masks[:, :2] = True return embs, pad_masks, att_masks
noisy_actions 的维度结构
noisy_actions.shape = (batch_size, n_action_steps, max_action_dim)
各维度含义:
-
第0维 (batch_size) :批次大小,通常为1或更多 -
第1维 (n_action_steps) :动作序列长度,默认为50步 -
第2维 (max_action_dim) :动作的最大维度,为了支持不同机器人而设置为18
两个时间概念
Flow Matching中的
这是生成过程的数学参数,不是物理时间,可以将其理解为去噪进度:
-
t=0:完全的噪声数据 -
t=1:完全的真实动作序列 -
t=0.3:30%真实动作 + 70%噪声的混合
这个过程发生在模型推理时,不是机器人执行时
机器人的实际执行时间
-
物理时刻: 执行动作 -
物理时刻: 执行动作 -
物理时刻2: 执行动作 -
... -
物理时刻: 执行动作
这些是真实世界的50个不同时刻,与flow matching的t无关。
Step1:状态嵌入
state_emb = self.state_proj(state)
-
self.state_proj是一个线性层:nn.Linear(self.config.max_state_dim, self.config.proj_width) -
将机器人状态从原始维度(如关节角度)映射到transformer的嵌入维度 -
例如:7维关节角度(已经在 prepare_state中扩充为max_state_dim维度) → 1024维嵌入向量
Step2:时间步嵌入
time_emb = create_sinusoidal_pos_embedding(
timestep, # flowmathing里面的时间
self.config.proj_width,
min_period=4e-3,
max_period=4.0,
device=device,
)
time_emb = time_emb.type(dtype=dtype)
调用 create_sinusoidal_pos_embedding 函数,完成时间步嵌入
defcreate_sinusoidal_pos_embedding(
time: torch.tensor,
dimension: int,
min_period: float,
max_period: float,
device="cpu",
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions.""" if dimension % 2 != 0: raise ValueError(f"dimension ({dimension}) must be divisible by 2") if time.ndim != 1: raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") fraction = torch.linspace( 0.0, 1.0, dimension // 2, dtype=torch.float32, device=device ) period = min_period * (max_period / min_period) ** fraction # Compute the outer product scaling_factor = 1.0 / period * 2 * math.pi sin_input = scaling_factor[None, :] * time[:, None] pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) return pos_emb
数学公式
公式推导
-
批大小为 ,输出维度为 (要求 为偶数),令 。
-
设最小周期 ,最大周期 ,令 。
-
当 时,频率在对数尺度上线性分布:
对 ,定义
当 时,取唯一的周期
-
对应的角频率为
-
给定一个标量时间/位置 ,定义正弦-余弦位置编码为
-
给定一组时间 时
这里的 指的是每个batch 的时间步值,也就是说和没有关系令
则外积
最终的嵌入矩阵
其中“
|”表示按列拼接。频率范围端点(当 ):
且在对数尺度上线性均匀覆盖区间 。
为什么这样设计
-
多尺度覆盖(log 均匀频率)
令周期按指数(几何)级数分布,等价于频率在对数尺度上线性分布,可同时覆盖低频(大周期)与高频(小周期)模式,捕捉从缓慢变化到快速变化的多尺度信息。
-
使用的原因
周期为 的正弦波应写为 。写成 可直接用角频率,单位一致并且语义直观: 就是“实际周期”。
-
成对的 sin 和 cos
对任意位移 Δ,
即平移对应于每个频率通道上的线性旋转变换,这使得模型更容易从绝对位置泛化到相对位移。
-
内积只与相对位置有关(对每个频率对)
对同一频率通道,利用恒等式 :
-
整体嵌入的内积成为多频率余弦核的和:
只依赖差值 ,有利于捕捉相对位置信息。
-
维度必须为偶数
每个频率占用两个维度(sin 与 cos),故 需为偶数,。
公式和代码的逐行对应
1.维度与形状检查
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
-
要求 为偶数,即 -
期望 的形状是 ,即
-
对数线性采样周期(或对数线性采样频率)
fraction = torch.linspace(
0.0, 1.0, dimension // 2, dtype=torch.float32, device=device
)
period = min_period * (max_period / min_period) ** fraction
3.角频率计算
scaling_factor = 1.0 / period * 2 * math.pi
-
外积(批量化)与逐元素非线性
sin_input = scaling_factor[None, :] * time[:, None]
-
scaling_factor原形状为 。加上None(等价于unsqueeze(0))后变成 ,即一行K列 -
time原形状为 。用[:, None](等价于unsqueeze(-1))后变成 ,即B行一列。 -
按 PyTorch 的广播规则: 与 相乘,会在维度为 1的位置进行扩展,结果形状为 。 -
任意位置 的元素为,这正是外积 的元素定义
等效的其他写法
-
sin_input = time[:,None]∗scaling_factor[None,:] -
sin_input=time[:,None] @ scaling_factor[None,:] -
sin_input=torch.outer(time, scaling_factor)
torch.sin(sin_input), torch.cos(sin_input)
-
对 的每个元素施加 sin、cos
-
拼接得到最终嵌入
pos_emb = torch.cat([sin(...), cos(...)], dim=1)
示例
-
选取:
-
参数设定
-
维度 → 频率通道数
-
-
-
周期按几何级数:
-
角频率:
-
取批量时间向量:
-
计算过程
-
先算外积相位矩阵
-
对 逐元素取正弦与余弦,并按列拼接得到最终嵌入:
-
即
4.验证代码
import math
import torchfrom torch import Tensor# 你的函数(原样拷贝)defcreate_sinusoidal_pos_embedding(
time: torch.tensor,
dimension: int,
min_period: float,
max_period: float,
device="cpu",
) -> Tensor: """Computes sine-cosine positional embedding vectors for scalar positions.""" if dimension % 2 != 0: raise ValueError(f"dimension ({dimension}) must be divisible by 2") if time.ndim != 1: raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") fraction = torch.linspace( 0.0, 1.0, dimension // 2, dtype=torch.float32, device=device ) period = min_period * (max_period / min_period) ** fraction scaling_factor = 1.0 / period * 2 * math.pi sin_input = scaling_factor[None, :] * time[:, None] pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) return pos_emb# 例子 A 的参数与 tD = 8min_period = 0.125max_period = 1.0t = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0], dtype=torch.float32)# 用你的函数计算pos_code = create_sinusoidal_pos_embedding( time=t, dimension=D, min_period=min_period, max_period=max_period, device="cpu")# 用“闭式公式矩阵”(上面手算得到的矩阵)作为期望结果pos_formula_closed = torch.tensor([ [0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, -1, 0], [0, 0, 0, 0, 1, 1, 1, -1], [0, 0, 0, -1, 1, 1, -1, 0], [0, 0, 0, 0, 1, 1, 1, 1],], dtype=torch.float32)# 校验(允许极小的浮点误差)ok = torch.allclose(pos_code, pos_formula_closed, rtol=1e-3, atol=1e-3)max_abs_diff = (pos_code - pos_formula_closed).abs().max().item()print("allclose:", ok)print("max_abs_diff:", max_abs_diff)print("pos_code:\n", pos_code)
Step3:动作嵌入
action_emb = self.action_in_proj(noisy_actions)
-
action_in_proj是一个线性投影层nn.Linear(self.config.max_action_dim, self.config.proj_width)
proj_width是啥?
-
从论文附录B可以看到,action expert使用较小的嵌入维度(嵌入向量的特征维度,即每个嵌入向量的长度)以加速推理:
-
proj_width=1024,这比主VLM backbone(2048)小,用于平衡性能和速度主VLM backbone使用预训练的PaliGemma模型,其嵌入维度已经固定
self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_with_export_config
)
Step4:时间嵌入扩展
time_emb = einops.repeat(time_emb, "b d -> b n d", n=action_emb.shape[1])
-
time_emb原始维度:(batch_size, proj_width) -
n=action_emb.shape[1]:即n_action_steps=50 -
einops.repeat:将时间嵌入在第1维重复n次
详细理解:
50个动作共享一个去噪时间步
在flow matching中,时间步 是全局的,但每个动作token都需要知道当前的"去噪进度"
怎么理解?
time_step = 0.3 # 全局去噪进度30%
time_emb = repeat(time_step_embedding, 50) # 复制给50个token,t=0.3对所有50个动作同时生效
Step5:动作-时间融合
action_time_emb = torch.cat([action_emb, time_emb], dim=-1)
-
在最后一维(特征维)拼接动作嵌入和时间嵌入 -
拼接后维度: (batch_size, 50, 2*proj_width)
这种拼接方式让MLP能够学习动作特征和时间信息之间的复杂交互关系,类似于论文中提到的"条件化"过程。
Step6:MLP处理
action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb) # swish activation
action_time_emb = self.action_time_mlp_out(action_time_emb)
这几行都为MLP结构
# __init__函数
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
维度变化:
-
输入: (B, 50, 2048) -
action_time_mlp_in: (B, 50, 1024) -
SiLU激活 -
action_time_mlp_out: (B, 50, 1024)
SiLU激活函数选择:
SiLU(Sigmoid Linear Unit)被证明在transformer架构中比ReLU更有效,特别是在需要平滑梯度的场景中。
Step7:最终嵌入组装
embs = torch.cat([state_emb[:, None], action_time_emb], dim=1)
详细解析:
-
state_emb[:, None]:为状态嵌入增加序列维度(B, s) → (B, 1, s) -
在序列维度(第1维)拼接状态和动作-时间嵌入
最终结构:
-
state_emb维度: (B, 1, proj_width)1个状态token -
action_time_emb维度: (B, 50, proj_width)50个动作token -
最终embs维度: (B, 51, proj_width)总共51个token
Step8:注意力掩码设置
att_masks = torch.zeros((bsize, action_time_dim + 1), device=device, dtype=torch.bool)
att_masks[:, :2] = True
-
att_masks[:, :2] = True:前两个位置设为True
为什么构造att_masks[:, :2] = True?
-
序列关系
序列结构:[状态token, 动作token0, 动作token1, ..., 动作token49]
位置索引:[ 0 , 1 , 2 , ..., 50 ] -
make_att_2d_masks函数的逻辑(后面详细解释)
# mask_ar: 1表示"前面的token不能依赖这个token",即因果边界
cumsum = torch.cumsum(att_masks, dim=1)
# cumsum = [1, 2, 2, 2, ..., 2]
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] -
注意力矩阵的生成
位置: 0 1 2 3 ... 50
位置0(状态): ✓ ✗ ✗ ✗ ... ✗ # 只能attend自己
位置1(动作): ✓ ✓ ✓ ✓ ... ✓ # 能attend所有token
位置2(动作): ✓ ✓ ✓ ✓ ... ✓ # 能attend所有token
...
位置50(动作): ✓ ✓ ✓ ✓ ... ✓ # 能attend所有token -
位置0(状态):边界阻止它attend后面的token
-
位置1(第1个动作):边界确保动作块开始,但动作内部没有更多边界
1D注意力向量转2D注意力矩阵 make_att_2d_masks
defmake_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to setup several types of attention, for example: [[1 1 1 1 1 1]]: pure causal attention. [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between themselves and the last 3 tokens have a causal attention. The first entry could also be a 1 without changing behaviour. [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a block can attend all previous blocks and all tokens on the same block. Args: input_mask: bool[B, N] true if its part of the input, false if padding. mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
Step 1: 输入验证
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
确保输入都是2维tensor:(batch_size, seq_len)
Step 2: 累积和计算
cumsum = torch.cumsum(att_masks, dim=1)
关键理解:cumsum是注意力"组ID"
# π0中的实际情况
att_masks = [True, True, False, False, ..., False] # 长度51
# 状态 动作0 动作1 动作2 动作49
cumsum = torch.cumsum(att_masks, dim=1)
# cumsum = [1, 2, 2, 2, ..., 2]
# 状态 动作组的所有token都是组ID=2
cumsum的含义:
-
每次遇到True(因果边界),组ID递增 -
同一组内的token有相同的cumsum值 -
不同组之间有因果关系
Step 3: 生成2D注意力掩码
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
维度变换规则:
cumsum.shape = (B, N)
cumsum[:, None, :].shape = (B, 1, N) # 扩展为行向量
cumsum[:, :, None].shape = (B, N, 1) # 扩展为列向量
# 广播比较得到
att_2d_masks.shape = (B, N, N)
注意力规则:att_2d_masks[i, j] = (cumsum[j] <= cumsum[i])
token_i可以attend token_j,当且仅当j的cumsum ≤ i的cumsum
举个例子:
-
π0的实际数据:
# 序列结构:[状态, 动作0, 动作1, 动作2, 动作3]
att_masks = [1, 1, 0, 0, 0]
cumsum = [1, 2, 2, 2, 2] -
2D掩码矩阵计算:
j=0 j=1 j=2 j=3 j=4 (被attend的token)
i=0(状态) 1≤12≤12≤12≤12≤1 → T F F F F
i=1(动作) 1≤22≤22≤22≤22≤2 → T T T T T
i=2(动作) 1≤22≤22≤22≤22≤2 → T T T T T
i=3(动作) 1≤22≤22≤22≤22≤2 → T T T T T
i=4(动作) 1≤22≤22≤22≤22≤2 → T T T T T -
注意力矩阵解读:
-
状态token():只能attend自己,形成孤立的注意力模式 -
动作token():可以attend所有token,实现双向注意力
Step 4: Padding掩码应用
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
举个例子:
假设我们有一个batch,序列长度为5:
pad_masks = torch.tensor([
[True, True, True, False, False] # 前3个是真实token,后2个是padding
])
生成2D padding掩码的代码如下
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
其中
pad_masks[:, None, :] = [[[True, True, True, False, False]]]
# shape: (1, 1, 5)
pad_masks[:, :, None] = [[[True],
[True],
[True],
[False],
[False]]]
# shape: (1, 5, 1)
根据广播乘法
# 广播后的乘法运算(逻辑与)
pad_2d_masks = [[[True, True, True, False, False]]] * [[[True],
[True],
[True],
[False],
[False]]]
# 结果:pad_2d_masks.shape = (1, 5, 5)
pad_2d_masks = [
[[ True, True, True, False, False], # 位置0 * [T,T,T,F,F]
[ True, True, True, False, False], # 位置1 * [T,T,T,F,F]
[ True, True, True, False, False], # 位置2 * [T,T,T,F,F]
[False, False, False, False, False], # 位置3 * [T,T,T,F,F]
[False, False, False, False, False]] # 位置4 * [T,T,T,F,F]
]
pad_2d_masks[i,j]的含义:
-
只有当位置和位置都是有效token时,才为 True -
即:两个padding位置之间不能有注意力连接
得到综合注意力矩阵
att_2d_masks = att_2d_masks & pad_2d_masks
既要满足因果关系,又要都是有效token
Flowmatching
从随机噪声生成50步机器人动作序列,使用flow matching的你 过程。
defsample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" bsize = state.shape[0] device = state.device dtype = state.dtype if noise isNone: actions_shape = ( bsize, self.config.n_action_steps, self.config.max_action_dim, ) noise = torch.randn(actions_shape, device=device, dtype=dtype) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, img_masks, lang_tokens, lang_masks ) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 # Compute image and language key value cache _, past_key_values = self.paligemma_with_expert.forward( attention_mask=prefix_att_2d_masks, position_ids=prefix_position_ids, past_key_values=None, inputs_embeds=[prefix_embs, None], use_cache=self.config.use_cache, fill_kv_cache=True, ) dt = torch.tensor(-1.0 / self.config.num_steps, dtype=dtype, device=device) x_t = noise time = torch.tensor(1.0, dtype=dtype, device=device) while time >= -dt / 2: expanded_time = time.expand(bsize) v_t = self.predict_velocity( state, prefix_pad_masks, past_key_values, x_t, expanded_time ) # Euler step x_t += dt * v_t time += dt return x_tdefpredict_velocity(self, state, prefix_pad_masks, past_key_values, x_t, timestep): """predict velocity at time t using the suffix model.""" suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix( state, x_t, timestep ) suffix_len = suffix_pad_masks.shape[1] batch_size = prefix_pad_masks.shape[0] prefix_len = prefix_pad_masks.shape[1] prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand( batch_size, suffix_len, prefix_len ) suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=[None, suffix_embs], use_cache=self.config.use_cache, fill_kv_cache=False, ) suffix_out = outputs_embeds[1] suffix_out = suffix_out[:, -self.config.n_action_steps :] v_t = self.action_out_proj(suffix_out) return v_t
KV缓存在双向注意力中的数学原理
分块注意力机制的数学公式
我们回顾注意力机制(Self-Attention)的数学公式:
其中:
-
:查询矩阵 -
:键矩阵 -
:值矩阵
在中,序列结构可以被表示为:
其中:
-
:图像+语言tokens -
:状态+动作tokens
根据序列结构,对注意力进行分块表示
此时,注意力分数矩阵可以写为:
考虑注意力掩码矩阵
符号含义:
-
:prefix序列长度 = 图像token数量 + 语言token数量 -
:suffix序列长度 = 1个状态token + 50个动作token = 51
矩阵维度的具体解释:
**** :
-
全1矩阵,大小为 -
表示prefix内部的双向注意力 -
所有图像token和语言token之间可以相互attend
**** :
-
全0矩阵,大小为 -
表示prefix不能attend suffix -
图像和语言token不能看到状态和动作token
**** :
-
全1矩阵,大小为 -
表示suffix可以attend prefix -
状态和动作token可以看到所有图像和语言token
**** :
-
suffix内部的注意力掩码,大小为 -
由 make_att_2d_masks函数根据因果规则生成 -
控制状态token和动作token之间的注意力模式
π0典型的序列长度:
-
:取决于图像patch数量和语言token数量 -
:1个状态token + 50个动作token
掩码后的注意力分数为
数学定义:
对于两个相同维度的矩阵 和 :
即对应位置的元素相乘。
在注意力掩码中的应用:
在之前的公式中:
含义:
-
:原始注意力分数矩阵 -
:注意力掩码矩阵(True/False 或 1/0) -
:应用掩码后的注意力分数
具体操作:
-
掩码为1的位置:保留原始注意力分数 -
掩码为0的位置:注意力分数变为0(阻止注意力连接)
示例:
这种操作实现了选择性的注意力屏蔽。
KV缓存的数学原理
prefix块的输出:
prefix输出的独立性:
即prefix的输出仅依赖于prefix自身,与suffix无关。
suffix块的输出:
suffix输出的依赖性:
即suffix的输出依赖于prefix和suffix。
缓存机制的数学实现
第一次前向传播
计算并缓存prefix的KV:
计算prefix的注意力输出:
后续前向传播()
重用缓存的prefix KV:
计算当前步的suffix KV:
拼接完整的KV矩阵:
计算suffix的注意力输出:
双向注意力兼容性的数学证明
依赖关系的单向性
prefix的计算不变性:
这是因为:
suffix的动态依赖:
计算复杂度分析
无缓存的复杂度:
有缓存的复杂度:
-
首次: -
后续每步:
节省的计算量:
数学等价性验证
完整计算:
缓存计算:
等价性:
因为且对所有成立。
Flowmatching的原理
π0中的时间参数定义
π0的插值公式(与标准相反):
其中:
-
:插值参数 -
:真实动作序列 -
:纯噪声 -
:30%噪声 + 70%真实动作
Flow Matching的速度场
速度场定义:
逆向积分的ODE:
推理过程:从噪声到真实数据
Step1:初始化噪声
if noise isNone:
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
noise = torch.randn(actions_shape, device=device, dtype=dtype)
生成原始噪音数据
矩阵维度:(B, 50, 18)
Step2:Prefix处理和KV缓存优化
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
这一部分的代码在之前都进行详细拆解过,在此略过
Step3:Flow Macthing积分循环
dt = torch.tensor(-1.0 / self.config.num_steps, dtype=dtype, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=dtype, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.predict_velocity(state, prefix_pad_masks, past_key_values, x_t, expanded_time)
# Euler step
x_t += dt * v_t
time += dt
积分参数:
欧拉积分法:
如何求解这个速度,在下一个模块详细解释
积分路径:
对应的数据变化:
终止条件:
这确保最后一步能完整执行到 。
predict_velocity函数详细解析
给定当前状态 和时间 ,预测速度场 。
Step1:生成suffix嵌入
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
state, x_t, timestep
)
-
suffix_embs: :1个状态token + 50个动作token的嵌入 -
= proj_width= 1024 (action_expert的嵌入维度) -
suffix_pad_masks: :padding掩码,由于不存在padding,故全为True -
suffix_att_masks: :因果掩码,前2个为True,其余为False
Step2:提取维度信息
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
-
suffix_len = 51- suffix序列长度 -
batch_size = B- 批次大小 -
prefix_len = L_p- prefix序列长度(图像+语言token数量)
Step3:构建prefix到suffix的注意力掩码
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
batch_size, suffix_len, prefix_len
)
数学含义:
即:suffix的每个token(行)都可以attend prefix的所有有效token(列)。
可视化(假设prefix有5个token,suffix有3个token):
prefix_pad_masks = [True, True, True, False, False]
prefix_pad_2d_masks = [
[True, True, True, False, False], # suffix token 0
[True, True, True, False, False], # suffix token 1
[True, True, True, False, False], # suffix token 2
]
为什么这样设计?
-
suffix的每个token需要知道prefix中哪些是有效token -
通过复制prefix的padding掩码,确保suffix只attend有效的prefix token -
这是一个单向的跨注意力掩码
Step4:生成suffix内部的注意力掩码
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
回顾输入:
suffix_att_masks = [True, True, False, False, ..., False] # 长度51
# 状态 动作0 动作1-49
make_att_2d_masks的计算:
生成的掩码矩阵:
j=0 j=1 j=2 ... j=50
i=0(状态) T F F ... F # 状态只能attend自己
i=1(动作0) T T T ... T # 动作可以attend所有
i=2(动作1) T T T ... T
...
i=50(动作49) T T T ... T
输出维度:
Step5:拼接完整的注意力掩码
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
拼接操作:
维度变化:
prefix_pad_2d_masks.shape = (B, 51, L_p)
suffix_att_2d_masks.shape = (B, 51, 51)
# 在第2维(列维度)拼接
full_att_2d_masks.shape = (B, 51, L_p + 51)
完整掩码矩阵的结构:
[prefix tokens: L_p] [suffix tokens: 51]
状态 [ 全部可见 ] [ 只能看自己 ]
动作0 [ 全部可见 ] [ 全部可见 ]
动作1 [ 全部可见 ] [ 全部可见 ]
...
动作49 [ 全部可见 ] [ 全部可见 ]
数学表示:
Step6:计算位置编码
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
-
prefix_offsets只统计有效token,自动跳过padding -
cumsum确保suffix内部的位置是连续的
示例:
# Prefix的padding掩码 (B=1, L_p=10)
prefix_pad_masks = torch.tensor([ [True, True, True, True, True, True, True, False, False, False]])# 有效tokens: 7个 padding: 3个# Suffix的padding掩码 (B=1, L_s=6)suffix_pad_masks = torch.tensor([ [True, True, True, True, True, True]])# 状态 + 5个动作,全部有效prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]# prefix_offsets = torch.tensor([[7]]) # shape: (1, 1)torch.cumsum(suffix_pad_masks, dim=1)# cumsum = torch.tensor([[1, 2, 3, 4, 5, 6]]) # shape: (1, 6)# Prefix的有效位置:0-6# position_ids = torch.tensor([[7, 8, 9, 10, 11, 12]]) # shape: (1, 6)
Step7:Transformer前向传播
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=self.config.use_cache,
fill_kv_cache=False,
)
-
attention_mask=full_att_2d_masks: -
形状: -
控制 suffix_tokens的注意力模式 -
position_ids=position_ids: -
形状: -
suffix tokens的位置编码 -
past_key_values=past_key_values: -
缓存的prefix KV矩阵 -
包含: -
inputs_embeds=[None, suffix_embs]: -
None:表示不计算prefix的新嵌入 -
suffix_embs:当前步的suffix嵌入,形状 -
use_cache=True: 保持KV缓存 -
fill_kv_cache=False: 不更新prefix的KV缓存
上述代码使用KV缓存优化了inputs_embeds = [prefix_embs, suffix_embs]步骤
-
不需要重新计算prefix的嵌入和KV
-
根据完整的注意力计算公式
其中:
直接从缓存中调用数据
-
:从 past_key_values获取 -
:从 suffix_embs计算
Step8:提取suffix输出
suffix_out = outputs_embeds[1]
-
outputs_embeds是一个列表:[prefix_output, suffix_output] -
我们只需要suffix的输出(以token形式存在)
Step9:提取动作tokens的输出
suffix_out = suffix_out[:, -self.config.n_action_steps :]
切片操作:
-
-self.config.n_action_steps=-50 -
提取最后50个token(动作tokens) -
丢弃第一个token(状态token)
维度变化:
为什么丢弃状态token?
-
状态token只用于提供上下文信息 -
速度场只需要从动作tokens预测 -
对应50步动作序列
Step10:投影到动作空间
v_t = self.action_out_proj(suffix_out)
维度变化:
输出含义:
Step11:返回速度场
return v_t
最终输出:
物理含义:
-
在当前状态 和时间 下 -
动作序列应该朝哪个方向"移动" -
这个方向由神经网络学习得到