Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions skills/implement-method/SKILL.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
---
name: implement-method
description: Implements a new method (scalarizer or aggregator) in TorchJD, starting from the research produced by the research-method skill and following the established file-by-file conventions. Use when a contributor wants to add the actual implementation of a scalarizer or aggregator that has already been investigated and listed in the tracking issues.
---

# Implement new method

This skill implements a new method by recovering its research, comparing the paper against the
existing implementations, settling the non-standard parts of its interface, and producing the full
set of TorchJD files (class, docs, tests, changelog) that match the established conventions.

It is the companion of the `research-method` skill: that one investigates a method and records a row
in a tracking issue; this one turns that row into a merged implementation.

**For agents:** invoke as `/implement-method method-name (paper-name)` (e.g.
`/implement-method stch (Smooth Tchebycheff Scalarization for Multi-objective Optimization)`).
If no method name is provided, ask the user for the name of the method and the title of the paper.

**For humans:** follow the numbered steps below to guide your development.

---

## Instructions

### Step 1: Recover the research context

Determine whether the method should be a **scalarizer** or an **aggregator**, then read everything
the `research-method` skill already found about it:

- Scalarizers are tracked in https://github.com/SimplexLab/TorchJD/issues/667, aggregators in
https://github.com/SimplexLab/TorchJD/issues/665. Fetch the relevant issue and find the row for
this method. Read every column:
- **Ref** — the paper (open it; you will need the exact equations / algorithm).
- **Stateful** — whether and how the method holds state.
- **Existing implementations** — links to the official repo (if any) and the best-known
third-party ones (LibMTL, libmoon, pymoo, ...), ideally with the exact file(s) and line(s).
- **Special Remarks** — may link to a full research write-up (e.g. a `claude.ai` share produced by
`research-method`). Read it if present.
- The most valuable inputs are the **non-standard interface aspects** uncovered during research:
statefulness, trainable parameters, randomness, warm-up / history buffers, statistics beyond the
`forward` values (e.g. per-task losses for an aggregator), and preconditions. If these are not
fully captured in the issue, **ask the user to share the `research-method` findings** before
continuing. Do not guess them.

If the method is not in the tracking issue yet, run `research-method` first.

### Step 2: Load the implementation reference for this method type

Read only the reference matching the method type, to keep context focused:

- **Scalarizer** → read `references/scalarizers.md`.
- **Aggregator** → read `references/aggregators.md`.

Each reference lists the exact files to create/edit and the TorchJD-specific conventions, with the
closest existing methods to mirror.

### Step 3: Compare the paper with the existing implementations

Always do this — it is the step we invariably end up needing. Read the relevant equations / the
algorithm box in the paper, then read the official and best-known third-party implementations at the
exact files/lines from the tracking row.

Reconcile any discrepancies between them. The ones that most often bite:

- **Minimization vs maximization.** TorchJD minimizes losses; much MOO/evolutionary work is written
for maximization, with the minimization form buried in a footnote. Find it, and check the sign of
every reference / ideal-point subtraction.
- **Normalization.** A direction or weight vector may be normalized (`w / ‖w‖`) in the code but not
the paper, or vice versa.
- **Dead arguments.** An impl may accept a parameter (e.g. a reference point) yet silently ignore it.
- **Droppable terms.** An `abs` / `clamp` / `max(0, ·)` in the paper may be unnecessary under the
method's preconditions (e.g. non-negative weights); drop it only with a justification.
- **Other:** an extra factor, an init value, a stabilization / epsilon trick.

Decide which to follow, note **why**, and surface the disagreement to the user — the implementation
should be faithful to a clearly-stated source, not an unexplained blend.

### Step 4: Settle the interface and design decisions

Using the research findings (Step 1) and the comparison (Step 3), map each non-standard aspect onto
the closest existing pattern from the reference loaded in Step 2 (statefulness, trainable parameters,
an internal optimizer, a preference/reference vector, ...). Then settle, for any method type:

- **Preconditions** (e.g. positivity): enforce them (raise `ValueError`) or only document them, and
how `nan`/`inf` should propagate.
- Which constructor arguments are **required vs optional**, and their **defaults**.

List the non-standard parts and your proposed handling, and **confirm the design with the user
before writing code.** This is where most of the maintainer review happens, so settle it up front.

### Step 5: Implement the method

Follow the file-by-file checklist in the reference loaded at Step 2. Match the style, naming, and
conventions of the closest existing method. If you adapt code from a third-party repository, add the
license header to the source file and an entry to `NOTICES` (see the reference).

### Step 6: Verify

Run the checks listed in the reference (unit tests with `-W error`, lint, and the docs
build/doctest). GPU tests require a CUDA device; if you cannot run them, provide the exact commands
for the user to run on their GPU and report back the results.

### Step 7: Self-review the code you produced

Before opening anything, re-read your own diff against the requirements and improve what can be
improved. Check that:

- The class follows the closest existing method's conventions (the reference's checklist): correct
base class(es), `forward(self, values, /)` returning a 0-dim scalar, shape validation, `reset()`
for stateful methods, a correct `__repr__`, and the docstring conventions (`r"""` only with LaTeX,
`:class:` cross-ref, `.. math::` + bullet list, a usage doctest, `:param:` for each argument).
- The design decisions settled in Step 4 are actually reflected in the code, and any discrepancy
between the paper and the existing implementations (Step 3) is resolved deliberately, with a
comment or docstring note where it is non-obvious.
- The tests cover the documented edge cases and contracts, not just the happy path.
- All six files are present and consistent (class, `__init__.py`, `.rst`, toctree, test,
`CHANGELOG.md`), plus `NOTICES` + a license header if you adapted code.

Apply the fixes you find, then re-run the relevant checks from Step 6.

### Step 8: Open a draft PR

Create a new branch, commit, and open a **draft** pull request targeting `main`, following the
repository's PR conventions (a `CHANGELOG.md` entry under `[Unreleased] > ### Added`; when asked for
a PR description, output raw GitHub-flavored markdown in a fenced code block, with GitHub math syntax
`$...$` / `$$...$$` and no em dashes). Keep it a draft so the contributor can read the code
themselves before requesting maintainer review. Return the PR URL when done.

---
108 changes: 108 additions & 0 deletions skills/implement-method/references/scalarizers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Reference: implementing a `Scalarizer`

A `Scalarizer` reduces a tensor of values of any shape into a single scalar — the baseline that
combines *losses* directly (a plain `loss.backward()` then gives the gradient), as opposed to an
`Aggregator` which combines per-loss *gradients*. Base class: `Scalarizer` in
`src/torchjd/scalarization/_scalarizer_base.py`.

**Don't work from this file alone — read the closest existing class end-to-end (its `_*.py` + `.rst`
+ `test_*.py`) and mirror it.** This reference is the map and the non-obvious rules, not a template.

## Contract for the subclass

- Subclasses `Scalarizer` (an `nn.Module`); `forward(self, values: Tensor, /) -> Tensor` returns a
**0-dim** scalar.
- The parameter is named **`values`** (positional-only), not `losses` — `Scalarizer` is generic
(maintainer decision). Accepts **any shape** and reduces over all elements (flatten if needed).

## Files to create / edit (new scalarizer `Foo`)

1. `src/torchjd/scalarization/_foo.py` — the class.
2. `src/torchjd/scalarization/__init__.py` — add the import + the `__all__` entry.
3. `docs/source/docs/scalarization/foo.rst` — doc page (mirror `geometric_mean.rst`).
4. `docs/source/docs/scalarization/index.rst` — add `foo.rst` to the `.. toctree::`.
5. `tests/unit/scalarization/test_foo.py` — tests.
6. `CHANGELOG.md` — entry under `[Unreleased] > ### Added`.
7. *(Only if you adapt third-party code)* license header in `_foo.py` + an entry in `NOTICES`.

## Pick the pattern and mirror it

| Pattern | Mirror | File |
|---|---|---|
| Stateless one-liner | `GeometricMean`, `Mean`, `Sum` | `_geometric_mean.py`, `_mean.py`, `_sum.py` |
| Stateless + preference/reference vector | `STCH`, `COSMOS`, `PBI` | `_stch.py`, `_cosmos.py`, `_pbi.py` |
| Stateful, trainable parameter | `UW`, `IMTL-L` | `_uw.py`, `_imtl_l.py` |
| Stateful, non-trainable history buffer | `DWA` | `_dwa.py` |
| Internal optimizer + multi-call protocol | `FAMO` | `_famo.py` |

### Pattern-specific rules (the things not obvious from one file)

- **Trainable** (`UW`/`IMTL-L`): also subclass `Stateful` (`from torchjd._mixins import Stateful`)
and implement `reset()`. State is an `nn.Parameter`, init to a neutral default (usually `0`), with
a `shape: int | Sequence[int]` arg (`Foo(3)` → `(3,)`). Validate `values.shape` at call time
(`ValueError`). The params are in `.parameters()`, so the user passes them to the optimizer — show
this in a doctest. A trained per-position param makes it **not** permutation-invariant; don't
assert it. Add a `shape`-aware `__repr__`.
- **History buffer** (`DWA`): **no** `nn.Parameter` (`list(Foo().parameters())` must be empty); hold
state in a `register_buffer` (moves with `.to()`, can be created lazily from the first input
shape). Provide an explicit update method (e.g. scheduler-like `step()`); `forward` **detaches**
weights derived from the state; `reset()` clears the buffer.
- **Internal optimizer / multi-call** (`FAMO`): private `nn.Parameter` (`_w`) with `.grad` cleared
after each step; a lazily-created internal `torch.optim.Adam`; an `update(new_losses)` method;
`forward` detaches the weights. Read `_famo.py` before copying.
- **Preference / reference vector** (`STCH`/`COSMOS`/`PBI`): validate shapes at call time
(`ValueError`, like `Constant`); flatten `weights`/`values`/`reference` in `forward`. `reference`
(z*) usually defaults to `0`; `weights` is required or uniform per the paper. Watch `nan`-gradient
footguns — `‖x‖` has a `0/0` grad at `0` (use `sqrt(‖x‖² + eps)`, see `PBI`); cosine needs an
eps-clamped denominator (use `torch.nn.functional.cosine_similarity`, see `COSMOS`). Lock with a
test.

## Docstring conventions

- Use a **raw** `r"""` docstring **only** if it contains LaTeX (`:math:` / `.. math::`) so
backslashes stay single; plain `"""` otherwise.
- Start with the `:class:` cross-ref(s) (`:class:`~torchjd.scalarization.Scalarizer``, plus
`:class:`~torchjd.Stateful`` if stateful); link the paper by full title + URL.
- Multi-symbol math → a `.. math::` block + a bullet list defining each symbol (not one dense inline
paragraph; see `STCH`). Document every `:param:`. Add a usage doctest (for stateful methods show
the optimizer / `step()` / `update()` cadence). Note preconditions in `.. note::` and decide
whether to enforce (`ValueError`) or let `nan`/`inf` propagate.

## Tests

Mirror `test_geometric_mean.py` (stateless) or `test_uw.py` (stateful). Shared infra in
`tests/unit/scalarization/`: `_inputs.py` (`shapes = [[], [5], [3, 4], [2, 3, 4]]`, `all_inputs`);
`_asserts.py` (`assert_returns_scalar`, `assert_grad_flow`, `assert_permutation_invariant`);
`utils.tensors` helpers (`tensor_`, `rand_`, `randn_`, `ones_`, `zeros_`, `randperm_` — they respect
`PYTEST_TORCH_DEVICE`/`PYTEST_TORCH_DTYPE`; for stateful instances make a `_foo(shape)` helper that
`.to(device=DEVICE, dtype=DTYPE)`, see `test_uw.py`). Cover: `test_value` (hand-checked),
`test_expected_structure` + `test_grad_flow` (parametrized over shapes), `test_permutation_invariant`
**only if** invariant, the documented edge cases/contracts (e.g. assert `nan` propagates on a bad
input so a future clamp can't slip in; a `does_not_raise()`/`raises(ValueError)` shape table; `reset`
clears state; params train / buffer rolls), and `test_representations`.

## CHANGELOG

`- Added `Foo` from [Paper Title](url) (Venue Year), a `Scalarizer` that <one-line description>.`

## Third-party attribution (only if adapting code, e.g. `FAMO`)

Header comment in `_foo.py`: `# Partly adapted from <url> — <License>, Copyright (c) <year>
<author>. # See NOTICES for the full license text.` plus the full license text in `NOTICES`.

## Verify (from repo root)

```bash
uv run pytest tests/unit/scalarization -W error -v # new tests
uv run pytest tests/unit -W error # full unit regression
uv run ruff check && uv run ruff format --check # lint + format
uv run make doctest -C docs && uv run make clean -C docs && uv run make html -C docs
uv run pre-commit run --all-files
PYTEST_TORCH_DEVICE=cuda:0 uv run pytest tests/unit -W error # GPU (needs CUDA)
```

- If `uv run` re-syncs unexpectedly, prefix with `UV_NO_SYNC=1`. Docs build is strict (`-W -n`), so
an `.rst` title underline must match its title length.
- Treat CI as the source of truth. A pre-existing test unrelated to your change can fail by
a tiny float tolerance on other platforms; confirm your new tests pass and that nothing you
touched regressed, rather than chasing it.
Loading