qten.linalg

Package reference for qten.linalg.

linalg

Linear-algebra routines built on top of QTen tensors.

This package contains decomposition algorithms and tensor-aware numerical helpers that operate on Tensor objects while preserving symbolic dimension metadata.

Decompositions
  • eig General eigendecomposition for square tensor-valued matrices.
  • eigh Hermitian eigendecomposition.
  • eigvals General eigenvalues only.
  • eigvalsh Hermitian eigenvalues only.
  • qr QR factorization.
  • svd Singular value decomposition.
Convenience re-exports

Exported API

eig

eig(tensor: Tensor) -> EigH

Perform eigendecomposition on general square matrix axes.

This function applies torch.linalg.eig to the final two dimensions of a Tensor. The last two dimensions must span the same Hilbert space up to ray ordering so they can be interpreted as a square operator. Any leading dimensions are treated as batch dimensions and are preserved in both outputs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor whose last two dimensions form square matrices. | required |

Returns:

| Type | Description |
|---|---|
| EigH | EigH containing: eigenvalues, whose dtype is the complex dtype associated with the input and whose dims replace the matrix axes with one IndexSpace; and eigenvectors, whose dtype matches that complex dtype and whose dims are the leading batch dims followed by (row_dim, spectrum). |

Examples:

result = eig(tensor)
values = result.eigenvalues
vectors = result.eigenvectors
Notes

torch.linalg.eig does not guarantee any ordering of the eigenvalues. This function sorts eigenvalues lexicographically by (real, imag) and applies the same reordering to eigenvectors.
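The ordering rule can be sketched in pure Python for a single matrix. This is only an illustration under stated assumptions: the helper name `sort_eigenpairs_sketch` is hypothetical, eigenvectors are stored as columns of a nested list, and the library's internal `_sort_eigenpairs` may well be implemented differently (for example with batched torch operations).

```python
def sort_eigenpairs_sketch(values, vectors):
    """Sort eigenvalues by (real, imag) and permute eigenvector columns to match.

    Hypothetical illustration only; not the qten implementation.
    """
    # Indices that put the eigenvalues in lexicographic (real, imag) order.
    order = sorted(range(len(values)), key=lambda i: (values[i].real, values[i].imag))
    sorted_values = [values[i] for i in order]
    # Column j of `vectors` belongs to values[j], so columns move with the values.
    sorted_vectors = [[row[i] for i in order] for row in vectors]
    return sorted_values, sorted_vectors


values = [2 + 1j, 1 - 1j, 2 - 1j, 1 + 0j]
vectors = [[10, 20, 30, 40], [11, 21, 31, 41]]  # one column per eigenvalue
sorted_values, sorted_vectors = sort_eigenpairs_sketch(values, vectors)
```

Ties in the real part are broken by the imaginary part, so `1 - 1j` precedes `1 + 0j` and `2 - 1j` precedes `2 + 1j`.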

The returned tensors satisfy \(A V = V\Lambda\), where \(\Lambda\) is the diagonal matrix of eigenvalues. If the input matrix is diagonalizable, this gives the reconstruction \(A = V\Lambda V^{-1}\). In code, \(V\) is eigenvectors and \(\Lambda\) is the diagonal matrix built from eigenvalues.
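The two identities can be checked numerically with plain PyTorch. The sketch below assumes a raw `torch.Tensor` as a stand-in for the data held by a QTen Tensor, so it calls `torch.linalg.eig` directly and does not show symbolic dims.

```python
import torch

# Upper-triangular matrix with distinct eigenvalues 2 and 3, hence diagonalizable.
A = torch.tensor([[2.0, 1.0], [0.0, 3.0]], dtype=torch.float64)

eigenvalues, V = torch.linalg.eig(A)      # both outputs are complex128
Lambda = torch.diag_embed(eigenvalues)    # diagonal matrix built from eigenvalues
A_c = A.to(V.dtype)                       # promote A so the products type-check

# Defining relation: A V = V Lambda.
residual = A_c @ V - V @ Lambda

# Reconstruction, valid here because V is invertible.
A_rebuilt = V @ Lambda @ torch.linalg.inv(V)
```

For a defective (non-diagonalizable) matrix the first relation still holds, but `V` is singular and the reconstruction step is not available.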

Source code in src/qten/linalg/decompose.py
def eig(tensor: Tensor) -> EigH:
    r"""
    Perform eigendecomposition on general square matrix axes.

    This function applies [`torch.linalg.eig`](https://pytorch.org/docs/stable/generated/torch.linalg.eig.html)
    to the final two dimensions of a
    [`Tensor`][qten.linalg.tensors.Tensor]. The last two dimensions must span
    the same Hilbert space up to ray ordering so they can be interpreted as a
    square operator. Any leading dimensions are treated as batch dimensions and
    are preserved in both outputs.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form square matrices.

    Returns
    -------
    EigH
        [`EigH`][qten.linalg.decompose.EigH] containing:
        - `eigenvalues`, whose dtype is the complex dtype associated with the
          input and whose dims replace the matrix axes with one
          [`IndexSpace`][qten.symbolics.state_space.IndexSpace].
        - `eigenvectors`, whose dtype matches that complex dtype and whose dims
          are the leading batch dims followed by `(row_dim, spectrum)`.

    Examples
    --------
    ```python
    result = eig(tensor)
    values = result.eigenvalues
    vectors = result.eigenvectors
    ```

    Notes
    -----
    `torch.linalg.eig` does not guarantee any ordering of the eigenvalues. This
    function sorts eigenvalues lexicographically by `(real, imag)` and applies
    the same reordering to eigenvectors.

    The returned tensors satisfy \(A V = V\Lambda\), where \(\Lambda\) is the
    diagonal matrix of eigenvalues. If the input matrix is diagonalizable, this
    gives the reconstruction \(A = V\Lambda V^{-1}\). In code, \(V\) is
    `eigenvectors` and \(\Lambda\) is the diagonal matrix built from
    `eigenvalues`.
    """
    _assert_eig_dims(tensor)

    dim0 = tensor.dims[-2]
    target = tensor.align(-1, dim0)  # Align column space to match the row space
    eigenvalues, eigenvectors = torch.linalg.eig(target.data)
    eigenvalues, eigenvectors = _sort_eigenpairs(eigenvalues, eigenvectors)

    spectrum = IndexSpace.linear(eigenvalues.shape[-1])

    eigvals = Tensor(
        data=eigenvalues,
        dims=target.dims[:-2] + (spectrum,),
    )
    eigvecs = Tensor(
        data=eigenvectors,
        dims=target.dims[:-2] + (dim0, spectrum),
    )

    return EigH(eigvals, eigvecs)

eigh

eigh(tensor: Tensor) -> EigH

Perform Hermitian eigendecomposition on the last two tensor dimensions.

This function applies torch.linalg.eigh to the matrix axes of a Tensor. The final two dimensions must span the same Hilbert space up to ray ordering so they can be interpreted as a Hermitian operator. Any leading dimensions are treated as batch dimensions and are preserved in both outputs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor whose last two dimensions form Hermitian matrices. | required |

Returns:

| Type | Description |
|---|---|
| EigH | EigH containing: eigenvalues, whose dtype is the real dtype associated with the input and whose dims replace the matrix axes with one IndexSpace; and eigenvectors, whose dtype matches the input dtype and whose dims are the leading batch dims followed by (row_dim, spectrum). |

Examples:

result = eigh(tensor)
eigenvalues = result.eigenvalues
eigenvectors = result.eigenvectors
Notes

torch.linalg.eigh is differentiable for Hermitian inputs, but the gradients can be ill-defined or unstable when eigenvalues are degenerate or nearly degenerate. If you use this in autograd, consider stabilizing the spectrum (e.g., with a small perturbation) or avoiding backpropagation through eigenvectors when bands are expected to merge.

The original matrix is recovered by forming a diagonal matrix \(\Lambda\) from eigenvalues and evaluating \(V\Lambda V^\dagger\). In code, this is eigenvectors @ W @ eigenvectors.h(-2, -1), where W is the diagonal matrix built from eigenvalues.
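A plain-PyTorch sketch of this reconstruction follows. It assumes a raw `torch.Tensor` in place of a QTen Tensor, so the conjugate transpose is written with `conj().transpose(-2, -1)` rather than the library's `h(-2, -1)`.

```python
import torch

# 2x2 Hermitian matrix with eigenvalues 1 and 4 (trace 5, determinant 4).
H = torch.tensor(
    [[2.0, 1.0 - 1.0j], [1.0 + 1.0j, 3.0]],
    dtype=torch.complex128,
)

eigenvalues, V = torch.linalg.eigh(H)   # eigenvalues are real and ascending

# Diagonal matrix of eigenvalues, promoted to the eigenvector dtype.
Lambda = torch.diag_embed(eigenvalues).to(V.dtype)

# V Lambda V^dagger recovers the input.
H_rebuilt = V @ Lambda @ V.conj().transpose(-2, -1)
```

Note that the eigenvalues come back as `float64` even though the input is `complex128`, which matches the "real dtype associated with the input" wording above.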

Source code in src/qten/linalg/decompose.py
def eigh(tensor: Tensor) -> EigH:
    r"""
    Perform Hermitian eigendecomposition on the last two tensor dimensions.

    This function applies [`torch.linalg.eigh`](https://pytorch.org/docs/stable/generated/torch.linalg.eigh.html)
    to the matrix axes of a [`Tensor`][qten.linalg.tensors.Tensor]. The final
    two dimensions must span the same Hilbert space up to ray ordering so they
    can be interpreted as a Hermitian operator. Any leading dimensions are
    treated as batch dimensions and are preserved in both outputs.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form Hermitian matrices.

    Returns
    -------
    EigH
        [`EigH`][qten.linalg.decompose.EigH] containing:
        - `eigenvalues`, whose dtype is the real dtype associated with the
          input and whose dims replace the matrix axes with one
          [`IndexSpace`][qten.symbolics.state_space.IndexSpace].
        - `eigenvectors`, whose dtype matches the input dtype and whose dims
          are the leading batch dims followed by `(row_dim, spectrum)`.

    Examples
    --------
    ```python
    result = eigh(tensor)
    eigenvalues = result.eigenvalues
    eigenvectors = result.eigenvectors
    ```

    Notes
    -----
    `torch.linalg.eigh` is differentiable for Hermitian inputs, but the gradients
    can be ill-defined or unstable when eigenvalues are degenerate or nearly
    degenerate. If you use this in autograd, consider stabilizing the spectrum
    (e.g., with a small perturbation) or avoiding backpropagation through
    eigenvectors when bands are expected to merge.

    The original matrix is recovered by forming a diagonal matrix \(\Lambda\)
    from `eigenvalues` and evaluating \(V\Lambda V^\dagger\). In code, this is
    `eigenvectors @ W @ eigenvectors.h(-2, -1)`, where `W` is that diagonal
    matrix.
    """
    _assert_eig_dims(tensor)

    dim0 = tensor.dims[-2]
    target = tensor.align(-1, dim0)  # Align column space to match the row space
    eigenvalues, eigenvectors = torch.linalg.eigh(target.data)

    spectrum = IndexSpace.linear(eigenvalues.shape[-1])

    eigvals = Tensor(
        data=eigenvalues,
        dims=target.dims[:-2] + (spectrum,),
    )
    eigvecs = Tensor(
        data=eigenvectors,
        dims=target.dims[:-2] + (dim0, spectrum),
    )

    return EigH(eigvals, eigvecs)

eigvalsh

eigvalsh(tensor: Tensor) -> Tensor

Compute Hermitian eigenvalues on the last two tensor dimensions.

This is the eigenvalues-only companion to eigh. The last two dimensions must span the same Hilbert space up to ray ordering and represent a Hermitian operator. Leading dimensions are treated as batch dimensions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor whose last two dimensions form Hermitian matrices. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Eigenvalues as a Tensor whose dtype matches the real dtype associated with the input and whose dims keep the leading batch dimensions while replacing the matrix axes with a single IndexSpace. |

Examples:

values = eigvalsh(tensor)
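
To illustrate how leading dimensions act as batch dimensions, here is a sketch using plain PyTorch. It assumes a raw `torch.Tensor` as a stand-in for the numeric data of a QTen Tensor, so only shapes are shown, not symbolic dims.

```python
import torch

# One 2x2 Hermitian (here real symmetric) matrix with eigenvalues -1 and +1.
base = torch.tensor([[0.0, 1.0], [1.0, 0.0]], dtype=torch.float64)

# Stack three scaled copies: shape (3, 2, 2), i.e. one leading batch axis.
batch = torch.stack([base, 2 * base, 3 * base])

# The two matrix axes collapse into one spectral axis: shape (3, 2).
values = torch.linalg.eigvalsh(batch)
```

Each row of `values` holds the ascending eigenvalues of the corresponding matrix in the batch.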
Source code in src/qten/linalg/decompose.py
def eigvalsh(tensor: Tensor) -> Tensor:
    """
    Compute Hermitian eigenvalues on the last two tensor dimensions.

    This is the eigenvalues-only companion to
    [`eigh`][qten.linalg.decompose.eigh]. The last two dimensions must span
    the same Hilbert space up to ray ordering and represent a Hermitian
    operator. Leading dimensions are treated as batch dimensions.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form Hermitian matrices.

    Returns
    -------
    Tensor
        Eigenvalues as a [`Tensor`][qten.linalg.tensors.Tensor] whose dtype
        matches the real dtype associated with the input and whose dims keep
        the leading batch dimensions while replacing the matrix axes with a
        single [`IndexSpace`][qten.symbolics.state_space.IndexSpace].

    Examples
    --------
    ```python
    values = eigvalsh(tensor)
    ```
    """
    _assert_eig_dims(tensor)

    dim0 = tensor.dims[-2]
    target = tensor.align(-1, dim0)  # Align column space to match the row space
    eigenvalues = torch.linalg.eigvalsh(target.data)

    spectrum = IndexSpace.linear(eigenvalues.shape[-1])

    vals = Tensor(
        data=eigenvalues,
        dims=target.dims[:-2] + (spectrum,),
    )

    return vals

eigvals

eigvals(tensor: Tensor) -> Tensor

Compute eigenvalues of general square matrix axes.

This is the eigenvalues-only companion to eig. The last two dimensions must span the same Hilbert space up to ray ordering and represent a square operator. Leading dimensions are treated as batch dimensions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor whose last two dimensions form square matrices. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Eigenvalues as a Tensor whose dtype matches the complex dtype associated with the input and whose dims keep the leading batch dimensions while replacing the matrix axes with a single IndexSpace. |

Notes

torch.linalg.eigvals does not guarantee any ordering of the eigenvalues. This function sorts eigenvalues lexicographically by (real, imag).
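The complex output dtype and the (real, imag) ordering can be seen on a real rotation matrix. This sketch assumes plain PyTorch in place of QTen, and it sorts with Python's `sorted` purely for illustration; how the library sorts internally is not shown here.

```python
import torch

# Real 90-degree rotation: eigenvalues are the conjugate pair -i and +i.
R = torch.tensor([[0.0, -1.0], [1.0, 0.0]], dtype=torch.float64)

raw = torch.linalg.eigvals(R)   # complex128 even though R is real

# Lexicographic order by (real, imag); equal real parts are ordered by imag.
ordered = sorted(raw.tolist(), key=lambda z: (z.real, z.imag))
```
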

Examples:

values = eigvals(tensor)
Source code in src/qten/linalg/decompose.py
def eigvals(tensor: Tensor) -> Tensor:
    """
    Compute eigenvalues of general square matrix axes.

    This is the eigenvalues-only companion to
    [`eig`][qten.linalg.decompose.eig]. The last two dimensions must span the
    same Hilbert space up to ray ordering and represent a square operator.
    Leading dimensions are treated as batch dimensions.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form square matrices.

    Returns
    -------
    Tensor
        Eigenvalues as a [`Tensor`][qten.linalg.tensors.Tensor] whose dtype
        matches the complex dtype associated with the input and whose dims keep
        the leading batch dimensions while replacing the matrix axes with a
        single [`IndexSpace`][qten.symbolics.state_space.IndexSpace].

    Notes
    -----
    `torch.linalg.eigvals` does not guarantee any ordering of the eigenvalues.
    This function sorts eigenvalues lexicographically by `(real, imag)`.

    Examples
    --------
    ```python
    values = eigvals(tensor)
    ```
    """
    _assert_eig_dims(tensor)

    dim0 = tensor.dims[-2]
    target = tensor.align(-1, dim0)  # Align column space to match the row space
    eigenvalues = torch.linalg.eigvals(target.data)
    eigenvalues, _ = _sort_eigenpairs(eigenvalues)

    spectrum = IndexSpace.linear(eigenvalues.shape[-1])

    vals = Tensor(
        data=eigenvalues,
        dims=target.dims[:-2] + (spectrum,),
    )

    return vals

qr

qr(tensor: Tensor) -> QR

Perform reduced QR decomposition on the last two tensor dimensions.

This function applies torch.linalg.qr with mode="reduced" to the matrix axes of the input tensor. The last two dimensions may be rectangular. Any leading dimensions are treated as batch dimensions and are preserved in both outputs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor whose last two dimensions form matrices. | required |

Returns:

| Type | Description |
|---|---|
| QR | QR containing: Q, a Tensor with orthonormal columns and dims (..., row_dim, factor); and R, an upper-triangular Tensor with dims (..., factor, col_dim). |

Examples:

result = qr(tensor)
q = result.Q
r = result.R
Notes

The shared factor axis is represented by an IndexSpace whose size equals the reduced QR bond dimension. The original matrix is recovered as \(Q R\), via Q @ R in code.
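A plain-PyTorch sketch of the reduced factorization on a tall matrix is given below. It assumes a raw `torch.Tensor` instead of a QTen Tensor; the variable `k` is an illustrative name for the size of the shared factor axis, which for reduced QR is the smaller of the two matrix sizes.

```python
import torch

# Tall 4x3 matrix with full column rank.
A = torch.arange(12, dtype=torch.float64).reshape(4, 3) + torch.eye(4, 3, dtype=torch.float64)

Q, R = torch.linalg.qr(A, mode="reduced")

# Size of the shared factor axis: min(rows, cols) = 3.
k = min(A.shape[-2], A.shape[-1])

# Q has shape (4, k) with orthonormal columns; R has shape (k, 3).
gram = Q.transpose(-2, -1) @ Q
A_rebuilt = Q @ R
```
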

Source code in src/qten/linalg/decompose.py
def qr(tensor: Tensor) -> QR:
    r"""
    Perform reduced QR decomposition on the last two tensor dimensions.

    This function applies [`torch.linalg.qr`](https://pytorch.org/docs/stable/generated/torch.linalg.qr.html)
    with `mode="reduced"` to the matrix axes of the input tensor. The last two
    dimensions may be rectangular. Any leading dimensions are treated as batch
    dimensions and are preserved in both outputs.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form matrices.

    Returns
    -------
    QR
        [`QR`][qten.linalg.decompose.QR] containing:
        - `Q`, a [`Tensor`][qten.linalg.tensors.Tensor] with orthonormal
          columns and dims `(..., row_dim, factor)`.
        - `R`, an upper-triangular
          [`Tensor`][qten.linalg.tensors.Tensor] with dims
          `(..., factor, col_dim)`.

    Examples
    --------
    ```python
    result = qr(tensor)
    q = result.Q
    r = result.R
    ```

    Notes
    -----
    The shared `factor` axis is represented by an
    [`IndexSpace`][qten.symbolics.state_space.IndexSpace] whose size equals the
    reduced QR bond dimension. The original matrix is recovered as \(Q R\), via
    `Q @ R` in code.
    """
    if tensor.rank() < 2:
        raise ValueError(
            "Input tensor must have at least two dimensions for matrix decomposition."
        )

    row_dim = tensor.dims[-2]
    col_dim = tensor.dims[-1]

    q_data, r_data = torch.linalg.qr(tensor.data, mode="reduced")
    spectral_dim = IndexSpace.linear(q_data.shape[-1])

    q = Tensor(
        data=q_data,
        dims=tensor.dims[:-2] + (row_dim, spectral_dim),
    )
    r = Tensor(
        data=r_data,
        dims=tensor.dims[:-2] + (spectral_dim, col_dim),
    )

    return QR(q, r)

svd

svd(
    tensor: Tensor,
    values_as_matrix: bool = False,
    full_matrices: bool = False,
) -> SVD

Perform singular value decomposition on the last two tensor dimensions.

This function applies torch.linalg.svd to the matrix axes of the input tensor and returns symbolic dimensions that distinguish reduced and full factorizations. The last two dimensions may be rectangular. Any leading dimensions are treated as batch dimensions and are preserved in all outputs.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor whose last two dimensions form matrices. | required |
| values_as_matrix | bool | If True, return singular values as an explicit diagonal matrix tensor. If False, return them as a vector on a single spectral axis. | False |
| full_matrices | bool | If True, compute full-sized U and Vh. If False, compute the reduced SVD. | False |

Returns:

| Type | Description |
|---|---|
| SVD | SVD containing: U, with dims (..., row_dim, factor) for reduced SVD or (..., row_dim, left_factor) for full SVD; S, with dims (..., factor) by default, (..., factor, factor) when values_as_matrix=True in reduced mode, or (..., left_factor, right_factor) in full matrix form; and Vh, with dims (..., factor, col_dim) for reduced SVD or (..., right_factor, col_dim) for full SVD. |

Examples:

result = svd(tensor)
u = result.U
s = result.S
vh = result.Vh
Notes

In reduced mode, factor is the shared singular-value IndexSpace. In full mode, left_factor and right_factor are sized to the full row and column spaces of the input matrix axes. The original matrix is recovered as \(U\Sigma V^\dagger\), using U @ Sigma @ Vh in code. Here Sigma is either the returned S tensor (values_as_matrix=True) or the diagonal matrix formed from the returned singular-value vector.
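The difference between the reduced and full forms can be sketched with plain PyTorch on a 2x3 matrix. This assumes a raw `torch.Tensor` rather than a QTen Tensor; the zero-padded `Sigma` below mirrors how a rectangular singular-value matrix is needed in full mode.

```python
import torch

A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float64)  # 2x3

# Reduced SVD: U is 2x2, S has 2 entries, Vh is 2x3.
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
A_reduced = U @ torch.diag_embed(S) @ Vh

# Full SVD: U is 2x2, S still has 2 entries, Vh is 3x3.
Uf, Sf, Vhf = torch.linalg.svd(A, full_matrices=True)

# Embed the singular values in a rows-by-cols matrix padded with zeros.
k = Sf.shape[-1]
Sigma = torch.zeros(A.shape[-2], A.shape[-1], dtype=Sf.dtype)
Sigma[:k, :k] = torch.diag_embed(Sf)
A_full = Uf @ Sigma @ Vhf
```

Both products recover `A`; only the shapes of the factors differ.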

Source code in src/qten/linalg/decompose.py
def svd(
    tensor: Tensor,
    values_as_matrix: bool = False,
    full_matrices: bool = False,
) -> SVD:
    r"""
    Perform singular value decomposition on the last two tensor dimensions.

    This function applies [`torch.linalg.svd`](https://pytorch.org/docs/stable/generated/torch.linalg.svd.html)
    to the matrix axes of the input tensor and returns symbolic dimensions that
    distinguish reduced and full factorizations. The last two dimensions may be
    rectangular. Any leading dimensions are treated as batch dimensions and are
    preserved in all outputs.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form matrices.
    values_as_matrix : bool, default `False`
        If `True`, return singular values as an explicit diagonal matrix
        tensor. If `False`, return them as a vector on a single spectral axis.
    full_matrices : bool, default `False`
        If `True`, compute full-sized `U` and `Vh`. If `False`, compute the
        reduced SVD.

    Returns
    -------
    SVD
        [`SVD`][qten.linalg.decompose.SVD] containing:
        - `U`, with dims `(..., row_dim, factor)` for reduced SVD or
          `(..., row_dim, left_factor)` for full SVD.
        - `S`, with dims `(..., factor)` by default, `(..., factor, factor)`
          when `values_as_matrix=True` in reduced mode, or
          `(..., left_factor, right_factor)` in full matrix form.
        - `Vh`, with dims `(..., factor, col_dim)` for reduced SVD or
          `(..., right_factor, col_dim)` for full SVD.

    Examples
    --------
    ```python
    result = svd(tensor)
    u = result.U
    s = result.S
    vh = result.Vh
    ```

    Notes
    -----
    In reduced mode, `factor` is the shared singular-value
    [`IndexSpace`][qten.symbolics.state_space.IndexSpace]. In full mode,
    `left_factor` and `right_factor` are sized to the full row and column
    spaces of the input matrix axes. The original matrix is recovered as
    \(U\Sigma V^\dagger\), using `U @ Sigma @ Vh` in code. Here `Sigma` is
    either the returned `S` tensor (`values_as_matrix=True`) or the diagonal
    matrix formed from the returned singular-value vector.
    """
    if tensor.rank() < 2:
        raise ValueError(
            "Input tensor must have at least two dimensions for matrix decomposition."
        )

    row_dim = tensor.dims[-2]
    col_dim = tensor.dims[-1]

    u_data, s_data, vh_data = torch.linalg.svd(tensor.data, full_matrices=full_matrices)

    factor = IndexSpace.linear(s_data.shape[-1])

    if full_matrices:
        left_factor = IndexSpace.linear(row_dim.dim)
        right_factor = IndexSpace.linear(col_dim.dim)
        u = Tensor(
            data=u_data,
            dims=tensor.dims[:-2] + (row_dim, left_factor),
        )
    else:
        u = Tensor(
            data=u_data,
            dims=tensor.dims[:-2] + (row_dim, factor),
        )
    if values_as_matrix:
        if full_matrices:
            k = s_data.shape[-1]
            s_mat = torch.zeros(
                *s_data.shape[:-1],
                left_factor.dim,
                right_factor.dim,
                dtype=s_data.dtype,
                device=s_data.device,
            )
            diag = torch.diag_embed(s_data)
            s_mat[..., :k, :k] = diag
            s = Tensor(
                data=s_mat,
                dims=tensor.dims[:-2] + (left_factor, right_factor),
            )
        else:
            s_mat = torch.diag_embed(s_data)
            s = Tensor(
                data=s_mat,
                dims=tensor.dims[:-2] + (factor, factor),
            )
    else:
        s = Tensor(
            data=s_data,
            dims=tensor.dims[:-2] + (factor,),
        )
    if full_matrices:
        vh = Tensor(
            data=vh_data,
            dims=tensor.dims[:-2] + (right_factor, col_dim),
        )
    else:
        vh = Tensor(
            data=vh_data,
            dims=tensor.dims[:-2] + (factor, col_dim),
        )

    return SVD(u, s, vh)

norm

norm(
    tensor: TensorType,
    ord: int | float | str | None = None,
    dim: int | tuple[int, int] | None = None,
) -> TensorType

Compute a vector or matrix norm with metadata-aware dimension reduction.

This forwards to torch.linalg.norm for the numeric computation, then removes the reduced axes from the symbolic output dims.

See Also

  • torch.linalg.norm: Official PyTorch reference for the underlying numeric operation.
  • torch.linalg.vector_norm: Clearer vector-only norm API in PyTorch.
  • torch.linalg.matrix_norm: Clearer matrix-only norm API in PyTorch.

Behavior

The interpretation of ord depends on dim:

  • dim is an int: compute a vector norm along that axis.
  • dim is a 2-tuple: compute a matrix norm over those two axes.
  • dim is None: follow PyTorch's torch.linalg.norm rules. In particular, ord=None flattens the tensor and computes a vector 2-norm, while any other ord requires the input to be 1D (vector norm) or 2D (matrix norm).
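
The three cases above can be sketched with plain PyTorch. This assumes a raw `torch.Tensor` in place of a QTen Tensor, so only the numeric result and its shape are shown, not the symbolic dims.

```python
import torch

x = torch.arange(1.0, 7.0, dtype=torch.float64).reshape(2, 3)  # [[1, 2, 3], [4, 5, 6]]

# dim is an int: vector 1-norm along axis 1, one value per row.
vec = torch.linalg.norm(x, ord=1, dim=1)

# dim is a 2-tuple: Frobenius matrix norm over both axes, a scalar.
mat = torch.linalg.norm(x, ord="fro", dim=(0, 1))

# dim is None and ord is None: flatten and take the vector 2-norm.
flat = torch.linalg.norm(x)
```

For this input the Frobenius norm and the flattened 2-norm coincide, since both are the square root of the sum of squared entries.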
Supported ord values

Vector-norm forms (dim is an int):

  • None
  • 0
  • any finite int or float
  • float("inf")
  • -float("inf")

Matrix-norm forms (dim is a 2-tuple):

  • None
  • "fro"
  • "nuc"
  • 1, -1
  • 2, -2
  • float("inf")
  • -float("inf")

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | The tensor to reduce. | required |
| ord | Optional[Union[int, float, str]] | Order of the norm forwarded to torch.linalg.norm. Common examples: ord=2 for the Euclidean vector norm or spectral matrix norm; ord=1 for an L1 vector norm or induced 1 matrix norm; ord=float("inf") for max-based norms; ord="fro" for the Frobenius matrix norm; ord="nuc" for the nuclear matrix norm. | None |
| dim | Optional[Union[int, Tuple[int, int]]] | Reduction axis or axes. int: vector norm; Tuple[int, int]: matrix norm; None: use PyTorch's default torch.linalg.norm behavior. | None |

Returns:

| Type | Description |
|---|---|
| TensorType | Tensor containing the requested norm values with reduced axes removed from dims. |

Raises:

| Type | Description |
|---|---|
| IndexError | If any requested reduction axis is out of range for the tensor rank. |
| ValueError | If dim contains duplicate axes. |
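
The dims bookkeeping and the two error conditions can be sketched in pure Python. The helper name `remaining_dims` and the string labels standing in for symbolic dimension objects are hypothetical; the sketch only mirrors the axis normalization described here, not the numeric reduction.

```python
def remaining_dims(dims, dim):
    """Return the dims left after reducing over `dim` (an int or a tuple of ints).

    Hypothetical illustration of the axis bookkeeping; labels stand in for
    symbolic dimension objects.
    """
    rank = len(dims)
    axes = (dim,) if isinstance(dim, int) else tuple(dim)

    normalized = []
    for d in axes:
        nd = d + rank if d < 0 else d  # map negative axes to positive indices
        if nd < 0 or nd >= rank:
            raise IndexError(f"Dimension index {d} out of range for rank {rank}")
        if nd in normalized:
            raise ValueError("norm dim entries must be unique")
        normalized.append(nd)

    # Keep every axis that was not reduced, in its original order.
    return tuple(label for i, label in enumerate(dims) if i not in normalized)


dims = ("batch", "row", "col")
after_vector_norm = remaining_dims(dims, -1)        # reduce the last axis
after_matrix_norm = remaining_dims(dims, (-2, -1))  # reduce both matrix axes
```

Note that `(1, -2)` counts as a duplicate for a rank-3 input, because both entries normalize to axis 1.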

Source code in src/qten/linalg/tensors.py
def norm(
    tensor: TensorType,
    ord: Optional[Union[int, float, str]] = None,
    dim: Optional[Union[int, Tuple[int, int]]] = None,
) -> TensorType:
    """
    Compute a vector or matrix norm with metadata-aware dimension reduction.

    This forwards to `torch.linalg.norm` for the numeric computation, then
    removes the reduced axes from the symbolic output dims.

    See Also
    --------
    [`torch.linalg.norm`](https://docs.pytorch.org/docs/stable/generated/torch.linalg.norm.html)
        Official PyTorch reference for the underlying numeric operation.
    [`torch.linalg.vector_norm`](https://docs.pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html)
        Clearer vector-only norm API in PyTorch.
    [`torch.linalg.matrix_norm`](https://docs.pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html)
        Clearer matrix-only norm API in PyTorch.

    Behavior
    --------
    The interpretation of `ord` depends on `dim`:

    - `dim` is an `int`: compute a vector norm along that axis.
    - `dim` is a 2-tuple: compute a matrix norm over those two axes.
    - `dim is None`: follow PyTorch's `torch.linalg.norm` rules. In
      particular, `ord=None` flattens the tensor and computes a vector 2-norm,
      while `ord != None` expects PyTorch's documented 1D/2D behavior.

    Supported `ord` values
    ----------------------
    Vector-norm forms (`dim` is an `int`)
    - `None`
    - `0`
    - any finite `int` or `float`
    - `float("inf")`
    - `-float("inf")`

    Matrix-norm forms (`dim` is a 2-tuple)
    - `None`
    - `"fro"`
    - `"nuc"`
    - `1`, `-1`
    - `2`, `-2`
    - `float("inf")`
    - `-float("inf")`

    Parameters
    ----------
    tensor : Tensor
        The tensor to reduce.
    ord : Optional[Union[int, float, str]], optional
        Order of the norm forwarded to `torch.linalg.norm`.

        Common examples:
        - `ord=2` for the Euclidean vector norm or spectral matrix norm
        - `ord=1` for an L1 vector norm or induced 1 matrix norm
        - `ord=float("inf")` for max-based norms
        - `ord="fro"` for the Frobenius matrix norm
        - `ord="nuc"` for the nuclear matrix norm
    dim : Optional[Union[int, Tuple[int, int]]], optional
        Reduction axis or axes.

        - `int`: vector norm
        - `Tuple[int, int]`: matrix norm
        - `None`: use PyTorch's default `torch.linalg.norm` behavior

    Returns
    -------
    TensorType
        Tensor containing the requested norm values with reduced axes removed
        from `dims`.

    Raises
    ------
    IndexError
        If any requested reduction axis is out of range for the tensor rank.
    ValueError
        If `dim` contains duplicate axes.
    """
    reduced = torch.linalg.norm(tensor.data, ord=ord, dim=dim)
    if dim is None:
        return replace(tensor, data=reduced, dims=())

    rank_ = tensor.rank()
    dims_tuple: Tuple[int, ...]
    if isinstance(dim, int):
        dims_tuple = (dim,)
    else:
        dims_tuple = dim

    normalized_dims: list[int] = []
    for d in dims_tuple:
        nd = d
        if nd < 0:
            nd += rank_
        if nd < 0 or nd >= rank_:
            raise IndexError(f"Dimension index {d} out of range for rank {rank_}")
        if nd in normalized_dims:
            raise ValueError("norm dim entries must be unique")
        normalized_dims.append(nd)

    reduced_dims_set = set(normalized_dims)
    new_dims = tuple(
        current_dim
        for idx, current_dim in enumerate(tensor.dims)
        if idx not in reduced_dims_set
    )
    return replace(tensor, data=reduced, dims=new_dims)