  • tinker: the training SDK. Users issue all kinds of training requests by calling its API; the underlying training logic is implemented by the Tinker backend.
  • tinker-cookbook: practical examples of fine-tuning LLMs, built on top of the Tinker API, together with a set of general-purpose abstractions.

Tinker

import tinker

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="meta-llama/Llama-3.2-1B", rank=32,
)

# Training loop primitives
training_client.forward_backward(...)  # compute the loss and accumulate gradients
training_client.optim_step(...)        # apply an optimizer update
training_client.save_state(...)        # checkpoint weights + optimizer state
training_client.load_state(...)        # resume from such a checkpoint

# Export the current weights and sample from them
sampling_client = training_client.save_weights_and_get_sampling_client(name="my_model")
sampling_client.sample(...)

Sampling from an image

import requests
import tinker
from transformers import AutoTokenizer
 
model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
 
service_client = tinker.ServiceClient()
training_client = await service_client.create_lora_training_client_async(base_model=model_name, rank=32)
sampling_client = await training_client.save_weights_and_get_sampling_client_async(name="sampler")
 
# Grab an image and ask a question
image_data = requests.get("https://thinkingmachines.ai/blog/on-policy-distillation/images/chess.png").content
model_input = tinker.ModelInput(chunks=[
    tinker.types.EncodedTextChunk(tokens=tokenizer.encode("<|im_start|>user\n<|vision_start|>")),
    tinker.types.ImageChunk(data=image_data, format="png"),
    tinker.types.EncodedTextChunk(tokens=tokenizer.encode("<|vision_end|>What is this?<|im_end|>\n<|im_start|>assistant\n")),
])
 
result = await sampling_client.sample_async(prompt=model_input, num_samples=1, sampling_params=tinker.types.SamplingParams(max_tokens=100))
print(tokenizer.decode(result.sequences[0].tokens))

Loss Function

When calling forward_backward, you specify the loss function:

  • Input: forward_backward expects a certain set of input tensors, passed in via datum.loss_fn_inputs, which is a dict mapping str to either a numpy array or a torch tensor.
  • Output: forward_backward returns a ForwardBackwardOutput, which carries a set of output tensors in fwd_bwd_result.loss_fn_outputs.
import tinker
import torch
from tinker import TensorData
 
# Create training data with required inputs
datum = tinker.Datum(
    model_input=input_tokens,
    loss_fn_inputs={
        "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
        "logprobs": TensorData.from_torch(torch.tensor(sampling_logprobs)),  # Reference logprobs
        "advantages": TensorData.from_torch(torch.tensor(advantages)),
    }
)
 
# Option 1: Use importance sampling REINFORCE
fwd_bwd_result = await training_client.forward_backward_async(
    [datum], loss_fn="importance_sampling"
)
 
# Option 2: Use PPO with clipping
fwd_bwd_result = await training_client.forward_backward_async(
    [datum], loss_fn="ppo"
)
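
Once the call returns, the output tensors and scalar diagnostics described above are read off the result, and the accumulated gradients are applied with an optimizer step. A minimal sketch, assuming loss_fn_outputs holds one dict of output tensors per datum and that optim_step is configured with AdamParams(learning_rate=...); check the API reference for the exact shapes:

# Read back the output tensors for the first (and only) datum
outputs = fwd_bwd_result.loss_fn_outputs[0]
target_logprobs = outputs["logprobs"]   # learner log-probs for the target tokens

# Apply the accumulated gradients with an optimizer step
await training_client.optim_step_async(tinker.types.AdamParams(learning_rate=1e-4))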

SFT: Cross Entropy

For SL, we implement the standard cross-entropy loss (i.e., negative log-likelihood), which optimizes the policy $p_{\theta}$ to maximize the log-probability of the tokens $x$:

$$\mathcal{L}(\theta) = -\mathbb{E}_{x} \left[ \log p_\theta(x) \right]$$ where weights is a per-token mask of 0s and 1s, typically generated from renderer.build_supervised_example(), which returns (model_input, weights) (i.e., it marks the desired assistant turns to train on).

# Apply weights and compute elementwise loss
elementwise_loss = -target_logprobs * weights
# Apply sum reduction to get the total loss
loss = elementwise_loss.sum()  # scalar
  • Input tensors:
    • target_tokens: array[(N,), int] - Target token IDs
    • weights: array[(N,), float] - Token-level loss weights (typically from the renderer)
  • Output tensors:
    • logprobs: array[(N,), float] - Log probabilities of predicted tokens
  • Output diagnostics:
    • loss:sum (scalar) - Sum of weighted cross-entropy losses
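
Putting these pieces together, a single supervised step might look like the sketch below. The token lists are placeholder data standing in for the (model_input, weights) pair produced by renderer.build_supervised_example(), and the loss_fn name "cross_entropy" is assumed from this section; a training_client as created above is in scope:

import torch
import tinker
from tinker import TensorData

# Placeholder token IDs; in practice (model_input, weights) come from
# renderer.build_supervised_example()
prompt_tokens = [1, 2, 3]
completion_tokens = [4, 5, 6]
tokens = prompt_tokens + completion_tokens

datum = tinker.Datum(
    model_input=tinker.ModelInput(chunks=[
        tinker.types.EncodedTextChunk(tokens=tokens[:-1]),  # inputs predict the next token
    ]),
    loss_fn_inputs={
        "target_tokens": TensorData.from_torch(torch.tensor(tokens[1:])),
        # 0 on prompt positions, 1 on the assistant tokens we want to train on
        "weights": TensorData.from_torch(torch.tensor(
            [0.0] * (len(prompt_tokens) - 1) + [1.0] * len(completion_tokens)
        )),
    },
)

fwd_bwd_result = await training_client.forward_backward_async([datum], loss_fn="cross_entropy")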

Policy Gradient: importance sampling

For RL, we implement a common variant of the policy gradient objective, used in practical settings where the learner policy $p_\theta$ may differ from the sampling policy $q$, which is common due to, e.g., non-determinism. The issue is that if these policies differ, then the objective $$\mathcal{L}(\theta) = \mathbb{E}_{x \sim p_\theta} \left[ A(x) \right]$$ is not computed in an unbiased way, because $x \sim q$ (sampler) does not exactly match the desired $x \sim p_\theta$ (learner). To correct the bias, we use a modified “importance sampling” objective: $$\mathcal{L}_{\text{IS}}(\theta) = \mathbb{E}_{x \sim q} \left[ \frac{p_\theta(x)}{q(x)} A(x) \right],$$ which yields the correct expected reward. In the formula above:

  • $\log p_\theta(x)$ – target_logprobs comes from the learner, computed during the forward part of the forward_backward pass.
  • $\log q(x)$ – sampling_logprobs comes from the sampler, recorded during sampling and used for the correction term.
# Compute probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Compute importance-weighted loss
loss = -(prob_ratio * advantages).sum()
  • Input tensors:
    • target_tokens: array[(N,), int] - Target token IDs (from the sampler $q$)
    • logprobs: array[(N,), float] - sampling_logprobs for the tokens
    • advantages: array[(N,), float] - Advantage values for RL (positive to reinforce, negative to discourage)
  • Output tensors:
    • logprobs: array[(N,), float] - target_logprobs for the tokens
  • Output diagnostics:
    • loss:sum (scalar) - Sum of importance-weighted policy gradient losses
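
The target_tokens, sampling_logprobs, and advantages fed into datum.loss_fn_inputs above come from the rollout. A sketch of collecting them from a sampled sequence, assuming each returned sequence exposes .tokens and per-token .logprobs; prompt_input and the scalar reward are placeholders:

# Roll out from the current sampler and keep its logprobs for the correction term
result = await sampling_client.sample_async(
    prompt=prompt_input, num_samples=1,
    sampling_params=tinker.types.SamplingParams(max_tokens=100),
)
seq = result.sequences[0]

reward = 1.0                             # hypothetical scalar reward for this rollout
target_tokens = seq.tokens               # the sampled tokens become the targets
sampling_logprobs = seq.logprobs         # per-token logprobs recorded by the sampler (assumed field)
advantages = [reward] * len(seq.tokens)  # simplest case: the same advantage on every token
# These three arrays are exactly what datum.loss_fn_inputs expects above.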

PPO

# Compute probability ratio
prob_ratio = torch.exp(target_logprobs - sampling_logprobs)
# Apply clipping
clipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)
# Compute both objectives
unclipped_objective = prob_ratio * advantages
clipped_objective = clipped_ratio * advantages
# Take minimum (most conservative)
ppo_objective = torch.min(unclipped_objective, clipped_objective)
# PPO loss is negative of objective
loss = -ppo_objective.sum()

Customizing the clip thresholds

fwd_bwd_result = await training_client.forward_backward_async(
    data=data,
    loss_fn="ppo",
    loss_fn_config={"clip_low_threshold": 0.9, "clip_high_threshold": 1.1}
)

CISPO

DRO

Saving and loading weights

The main APIs:

  1. save_weights_for_sampler(): saves a copy of the model weights that can be used for sampling.
  2. save_state(): saves the weights and the optimizer state. You can fully resume training from this checkpoint.
  3. load_state(): loads the weights and the optimizer state so you can fully resume training from a checkpoint saved with save_state().
# Setup
import tinker
service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(
    base_model="meta-llama/Llama-3.2-1B", rank=32
)
 
# Save a checkpoint that you can use for sampling
sampling_path = training_client.save_weights_for_sampler(name="0000").result().path
 
# Create a sampling client with that checkpoint
sampling_client = service_client.create_sampling_client(model_path=sampling_path)
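
For full training resumption, the analogous pattern uses save_state() and load_state(). A sketch, assuming the keyword arguments mirror save_weights_for_sampler above and that load_state accepts the saved path:

# Save weights + optimizer state, then resume from that checkpoint later
state_path = training_client.save_state(name="0001").result().path

# ... later, e.g. in a new process
training_client = service_client.create_lora_training_client(
    base_model="meta-llama/Llama-3.2-1B", rank=32
)
training_client.load_state(state_path)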

Downloading Weights

import requests

rest_client = service_client.create_rest_client()
future = rest_client.get_checkpoint_archive_url_from_tinker_path(sampling_client.model_path)
archive_url = future.result()  # assumed to be a (signed) URL to the checkpoint archive
with open("model-checkpoint.tar.gz", "wb") as f:
    f.write(requests.get(archive_url).content)  # download the archive from that URL
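
The downloaded archive is a standard tarball and can be unpacked with the Python standard library:

import tarfile

# Unpack the downloaded checkpoint archive into a local directory
with tarfile.open("model-checkpoint.tar.gz", "r:gz") as tar:
    tar.extractall("model-checkpoint")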

Tinker Cookbook

References