[New feature] Support multi-teacher online distillation (MOPD) #239

Open
doctorMcy wants to merge 2 commits into
modelscope:mainfrom
doctorMcy:feature_mopd

@doctorMcy

PR type

  • Bug Fix
  • [ √ ] New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

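Adds support for multi-teacher online distillation, with two selectable loss computations:
1. ctkd: uses vocabulary mapping to handle the case where the teacher and student models have heterogeneous vocabularies and their outputs cannot be fitted directly. For example, the text "你好吗" is split into ["你", "好", "吗"] under the teacher's vocabulary but into ["你好吗"] under the student's. The two probability sequences, p_t = [P_t("你"), P_t("好"), P_t("吗")] and p_s = [P_s("你好吗")], then differ in sequence length. The teacher's sequence can be converted to p_t' = P_t("你好吗") = P_t("你") + σ·P_t("好") + σ²·P_t("吗"), after which the KL divergence is computed (https://arxiv.org/pdf/2605.21699).
2. gold: computes the KL divergence over the probability sequences of the spans where the teacher's and student's tokenizations of the original text match, and uses ULD for the spans that cannot be matched.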
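
The merging rule in item 1 can be illustrated in a few lines of Python. This is a minimal sketch based only on the formula quoted above. The function name merge_teacher_probs, the example probabilities, and the value of σ are assumptions for illustration and are not taken from the PR's code.

```python
def merge_teacher_probs(teacher_probs, sigma):
    """Collapse the teacher's per-token probabilities for one student token.

    Computes P(tok_0) + sigma * P(tok_1) + sigma**2 * P(tok_2) + ...
    as in the formula quoted in the PR description. `sigma` is treated
    here as a plain discount factor, which is an assumption.
    """
    return sum(p * sigma**k for k, p in enumerate(teacher_probs))


# Hypothetical teacher probabilities for the tokens "你", "好", "吗".
teacher_probs = [0.5, 0.25, 0.125]
merged = merge_teacher_probs(teacher_probs, sigma=0.5)
print(merged)  # 0.5 + 0.5 * 0.25 + 0.25 * 0.125
```

After merging, the teacher contributes one value per student token, so the two sequences have the same length and a token-wise divergence can be computed.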

Experiment results

Paste your experiment result here(if needed).

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces Cross-Tokenizer Knowledge Distillation (CTKD) and Universal Logit Distillation (GOLDLoss) implementations, along with NPU training recipes. The reviewer identified several critical issues, including incorrect tokenizer token passing between student and teacher models in the GOLD training script, potential AttributeError risks from direct .vocab access, device mismatch risks in the global projection matrix cache, incorrect vocabulary size usage in logit conversion, hardcoded GPU allocation logic, and unhandled callable batches in Ray mode.

Comment thread src/twinkle/loss/ctkd.py
Comment on lines +208 to +214
config_data = {
    'student_vocab': self.student_tokenizer.vocab,
    'teacher_vocabs': [tokenizer.vocab for tokenizer in self.teacher_tokenizer_group],
    'max_length': self.max_length,
    'beta': self.beta,
    'gamma': self.gamma,
}

critical

student_tokenizer and teacher_tokenizer may not have a .vocab attribute (especially Hugging Face Fast Tokenizers), so accessing it directly can raise an AttributeError. Use the standard .get_vocab() interface to obtain the vocabulary instead.

        config_data = {
            'student_vocab': self.student_tokenizer.get_vocab(),
            'teacher_vocabs': [tokenizer.get_vocab() for tokenizer in self.teacher_tokenizer_group],
            'max_length': self.max_length,
            'beta': self.beta,
            'gamma': self.gamma,
        }

Comment thread src/twinkle/loss/gold.py
Comment on lines +73 to +76
if teacher_logits is None and teacher_topk_logprobs is not None and teacher_topk_indices is not None:
    # Get vocabulary size from student logits
    vocab_size = student_logits.size(-1)
    batch_size, seq_len, topk = teacher_topk_logprobs.shape

critical

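When converting the teacher model's top-k log-probabilities into full logits, the code uses the student's vocabulary size, student_logits.size(-1), as the teacher's vocab_size. In heterogeneous-vocabulary (cross-tokenizer) distillation, the teacher and student vocabulary sizes usually differ. If the teacher's vocabulary is larger than the student's, the scatter_ operation will crash with an out-of-bounds index; if it is smaller, a tensor of the wrong shape is created. Use len(self.teacher_tokenizer) (if available) to obtain the teacher's actual vocabulary size.

        if teacher_logits is None and teacher_topk_logprobs is not None and teacher_topk_indices is not None:
            # Get vocabulary size from teacher tokenizer if available, otherwise fallback to student
            vocab_size = len(self.teacher_tokenizer) if self.teacher_tokenizer is not None else student_logits.size(-1)
            batch_size, seq_len, topk = teacher_topk_logprobs.shape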
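
To make the failure mode concrete, the sketch below expands one position's top-k log-probabilities into a dense row, using plain Python lists as a stand-in for the tensor scatter. The helper topk_to_dense and both vocabulary sizes are hypothetical. The explicit bounds check mimics the out-of-range error a real scatter would raise and is not the PR's code.

```python
import math


def topk_to_dense(topk_logprobs, topk_indices, vocab_size):
    """Expand top-k log-probs for one position into a row of length vocab_size.

    Entries outside the top-k are filled with -inf (probability zero).
    """
    dense = [-math.inf] * vocab_size
    for logprob, index in zip(topk_logprobs, topk_indices):
        if not 0 <= index < vocab_size:
            raise IndexError(f"token id {index} is outside a vocab of size {vocab_size}")
        dense[index] = logprob
    return dense


teacher_vocab_size = 8  # hypothetical teacher vocabulary size
student_vocab_size = 5  # hypothetical, smaller student vocabulary
topk_logprobs = [-0.1, -2.5]
topk_indices = [6, 1]   # id 6 exists only in the teacher's vocabulary

# Sizing the row by the teacher's vocabulary works.
dense = topk_to_dense(topk_logprobs, topk_indices, teacher_vocab_size)

# Sizing it by the student's vocabulary fails on teacher id 6.
try:
    topk_to_dense(topk_logprobs, topk_indices, student_vocab_size)
    overflowed = False
except IndexError:
    overflowed = True
```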

Comment on lines +231 to +234
teacher_response = teacher_sampler.sample(
    batch,
    SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=64),
)

critical

Passing the batch encoded by the student tokenizer directly to teacher_sampler.sample is wrong. Because the teacher and student use different tokenizers, the teacher will interpret the student's token IDs as its own and compute completely wrong logits. As in ctkd_npu.py, first decode the batch to text with student_tokenizer, then build a Trajectory and pass it to teacher_sampler so that the teacher re-encodes the text with its own tokenizer.

        teacher_inputs = []
        for item in batch:
            text = student_tokenizer.decode(item['input_ids'], skip_special_tokens=False)
            teacher_inputs.append({'messages': [{'role': 'user', 'content': text}]})

        teacher_response = teacher_sampler.sample(
            teacher_inputs,
            SamplingParams(max_tokens=0, temperature=1.0, prompt_logprobs=64),
        )
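
Why the decode and re-encode round trip matters can be shown with two toy tokenizers. ToyTokenizer below is a hypothetical greedy longest-match tokenizer written only for this illustration; it is not part of twinkle or Hugging Face, and the vocabularies are made up.

```python
class ToyTokenizer:
    """Hypothetical greedy longest-match tokenizer standing in for a real one."""

    def __init__(self, pieces):
        self.piece_to_id = {p: i for i, p in enumerate(pieces)}
        self.id_to_piece = {i: p for p, i in self.piece_to_id.items()}

    def encode(self, text):
        ids, pos = [], 0
        while pos < len(text):
            # Pick the longest vocabulary piece that matches at `pos`.
            match = max(
                (p for p in self.piece_to_id if text.startswith(p, pos)),
                key=len,
            )
            ids.append(self.piece_to_id[match])
            pos += len(match)
        return ids

    def decode(self, ids):
        return "".join(self.id_to_piece[i] for i in ids)


student = ToyTokenizer(["你好吗", "你", "好", "吗"])
teacher = ToyTokenizer(["吗", "好", "你"])

student_ids = student.encode("你好吗")       # a single student token
misread = teacher.decode(student_ids)        # teacher reading student ids directly
text = student.decode(student_ids)           # decode with the student first
teacher_ids = teacher.encode(text)           # then re-encode with the teacher
```

Here the teacher reads the student's id 0 as a different piece of text, while the round trip through plain text gives the teacher its own, correct ids.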

Comment on lines +271 to +273
student_model.forward_backward(
    inputs=input_data,
    adapter_name=ADAPTER_NAME,

critical

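The call to student_model.forward_backward passes input_data as inputs. The teacher and student models use different tokenizers, and input_data holds token IDs re-encoded by the teacher model. The student model should receive data encoded with its own tokenizer (that is, batch); otherwise it receives wrong token IDs, which causes severe out-of-bounds errors or trains a completely garbled model. Change inputs=input_data to inputs=batch.

        student_model.forward_backward(
            inputs=batch,
            adapter_name=ADAPTER_NAME,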

Comment on lines +223 to +225
for batch in dataloader:
    if optim_step >= MAX_STEPS:
        break

critical

In Ray mode, the batch returned by the dataloader is a callable that usually has to be invoked to obtain the actual data. The current code passes batch straight to teacher_sampler.sample, which causes a runtime type error. As in ctkd_npu.py, add a callable(batch) check and call.

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break
        # In Ray mode, batch is a callable and must be called to get the actual data
        if callable(batch):
            batch = batch()
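
The same check can be factored into a small helper so that every training script handles both modes the same way. This is a sketch under the assumption that a lazy batch is a zero-argument callable; resolve_batch and the sample data are hypothetical names, not part of the PR.

```python
def resolve_batch(batch):
    """Return the concrete batch, invoking it first if it is a lazy callable."""
    return batch() if callable(batch) else batch


# An eager batch (local mode) and a lazy one (as described for Ray mode).
eager = [{"input_ids": [1, 2, 3]}]
lazy = lambda: [{"input_ids": [1, 2, 3]}]

resolved_eager = resolve_batch(eager)
resolved_lazy = resolve_batch(lazy)
```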

Comment thread src/twinkle/loss/ctkd.py
Comment on lines +173 to +180
if cache_key in _PROJECTION_MATRIX_CACHE:
    # Load from cache
    cached_data = _PROJECTION_MATRIX_CACHE[cache_key]
    self.projection_matrices = cached_data['projection_matrices']
    self.projection_student_indices_list = cached_data['projection_student_indices_list']
    self.projection_teacher_indices_list = cached_data['projection_teacher_indices_list']
    self.projection_values_list = cached_data['projection_values_list']
else:

high

The global cache _PROJECTION_MATRIX_CACHE stores references to tensors that may already have been moved to a specific device (self.device). If several CTKDLoss instances are created in the same process with different devices (or one of them with None), a later instance gets tensors from the cache that are bound to the earlier instance's device, leading to a cross-device RuntimeError. The cache should hold only CPU tensors, which are cloned and moved to the current instance's device when loaded.

        if cache_key in _PROJECTION_MATRIX_CACHE:
            # Load from cache
            cached_data = _PROJECTION_MATRIX_CACHE[cache_key]
            self.projection_matrices = cached_data['projection_matrices']
            self.projection_student_indices_list = [t.to(self.device) if self.device is not None else t.clone() for t in cached_data['projection_student_indices_list']]
            self.projection_teacher_indices_list = [t.to(self.device) if self.device is not None else t.clone() for t in cached_data['projection_teacher_indices_list']]
            self.projection_values_list = [t.to(self.device) if self.device is not None else t.clone() for t in cached_data['projection_values_list']]
        else:

Comment thread cookbook/rl/mopd/ctkd_npu.py Outdated

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 1))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS * 2 # Two teacher samplers need separate GPU resources

high

When SHARED_TEACHER_GPUS is True, the two teacher samplers share the same GPU resources, so the total number of GPUs required is MODEL_GPUS + SAMPLER_GPUS. However, NUM_GPUS is currently hardcoded as MODEL_GPUS + SAMPLER_GPUS * 2, which requests surplus GPUs in shared mode and may cause initialization to fail for lack of resources. Move the SHARED_TEACHER_GPUS definition earlier and compute NUM_GPUS dynamically.

SHARED_TEACHER_GPUS = bool(os.environ.get('SHARED_TEACHER_GPUS', False))
NUM_GPUS = MODEL_GPUS + (SAMPLER_GPUS if SHARED_TEACHER_GPUS else SAMPLER_GPUS * 2)
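
One caveat when reading a flag with bool(os.environ.get(...)): environment values are strings, and bool() of any non-empty string is True, so SHARED_TEACHER_GPUS=0 or SHARED_TEACHER_GPUS=False would still enable sharing. The sketch below shows one possible way to parse the flag explicitly. The helper names env_flag and total_gpus and the set of accepted truthy spellings are assumptions, not part of the PR.

```python
import os

_TRUE_VALUES = {"1", "true", "yes", "on"}  # assumed set of truthy spellings


def env_flag(name, default=False, environ=None):
    """Parse a boolean environment variable by its text, not its truthiness."""
    environ = os.environ if environ is None else environ
    raw = environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUE_VALUES


def total_gpus(model_gpus, sampler_gpus, shared_teacher_gpus):
    """Two teacher samplers need one GPU pool when shared, two otherwise."""
    return model_gpus + (sampler_gpus if shared_teacher_gpus else sampler_gpus * 2)


# bool() treats the string "False" as truthy; env_flag does not.
naive = bool("False")
parsed = env_flag("SHARED_TEACHER_GPUS", environ={"SHARED_TEACHER_GPUS": "False"})
```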

Comment thread src/twinkle/loss/gold.py Outdated
Comment on lines +230 to +237
def decode_tokens(tokenizer, token_ids):
    pieces = []
    prev = ""
    for k in range(len(token_ids)):
        cur = tokenizer.decode(token_ids[:k + 1], skip_special_tokens=False)
        pieces.append(cur[len(prev):])
        prev = cur
    return pieces

medium

decode_tokens is defined inside _compute_distillation_loss but is never called in that method (it is redefined and used in _build_alignment_groups_from_ids). Consider deleting this redundant dead code to keep the code clean.
