qten.linalg.decompose

Module reference for qten.linalg.decompose.

decompose

Tensor-aware matrix decomposition routines for QTen.

This module wraps PyTorch's dense linear-algebra decompositions so they operate on Tensor objects while preserving symbolic dimension metadata.

Public decompositions

  • eigh: Hermitian eigendecomposition returning EigH.
  • eigvalsh: Hermitian eigenvalues only.
  • eig: General eigendecomposition returning EigH.
  • eigvals: General eigenvalues only.
  • qr: QR factorization returning QR.
  • svd: Singular value decomposition returning SVD.

Conventions

All decompositions act on the last two tensor dimensions as matrix axes and preserve any leading dimensions as batch axes. The returned Tensor objects replace the matrix axes with one or more IndexSpace factors describing the decomposition bond dimensions.

For eigendecompositions, the last two dimensions must describe the same Hilbert space up to ray ordering so the matrix is square as a symbolic operator. For eigh and eigvalsh, that operator is additionally assumed to be Hermitian.
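The batch convention can be illustrated with the underlying PyTorch routine. The following is a minimal sketch that uses plain torch tensors rather than QTen Tensor objects, so it shows only the shape behavior and not the symbolic dims; the variable names are illustrative.

```python
import torch

torch.manual_seed(0)

# A batch of four 3x3 real symmetric matrices: shape (batch, row, col).
a = torch.randn(4, 3, 3, dtype=torch.float64)
a = a + a.transpose(-2, -1)  # symmetrize the last two (matrix) axes

eigenvalues, eigenvectors = torch.linalg.eigh(a)

# The leading batch axis is preserved. The two matrix axes become a single
# spectral axis for the eigenvalues and (row, spectrum) for the eigenvectors.
print(eigenvalues.shape)
print(eigenvectors.shape)
```

In QTen the same layout is expressed through dims, with an IndexSpace standing in for the spectral axis.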

EigH

Bases: NamedTuple

Eigen-decomposition result container.

This is the shared return type of both eigh and eig.

Reconstruction

The returned tensors encode the matrix factorization on the last two axes.

  • For eigh, reconstruct the original Hermitian matrix as \(A = V \Lambda V^\dagger\), where V = result.eigenvectors and \(\Lambda\) is the diagonal matrix formed from result.eigenvalues. In code, this is V @ W @ V.h(...), where W is that diagonal matrix.
  • For eig, the returned tensors satisfy the eigenvalue equation \(A V = V \Lambda\), where \(\Lambda\) is the diagonal matrix of eigenvalues. If the matrix is diagonalizable, it can also be reconstructed as \(A = V \Lambda V^{-1}\). In code, these correspond to products like A @ V, V @ W, and V @ W @ V.inv().
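Both reconstruction identities can be checked numerically. This is a sketch on raw torch tensors, not on QTen Tensor objects, so the conjugate transpose and inverse are spelled with torch calls instead of the V.h(...) and V.inv() methods mentioned above.

```python
import torch

torch.manual_seed(0)
m = torch.randn(4, 4, dtype=torch.complex128)

# Hermitian case: A = V diag(w) V^dagger.
a_herm = m + m.conj().transpose(-2, -1)
w, v = torch.linalg.eigh(a_herm)
lam = torch.diag_embed(w).to(v.dtype)  # real eigenvalues cast to the complex dtype
herm_ok = torch.allclose(v @ lam @ v.conj().transpose(-2, -1), a_herm)

# General case: A V = V diag(w) always holds for the returned pairs.
w_g, v_g = torch.linalg.eig(m)
lam_g = torch.diag_embed(w_g)
eigen_equation_ok = torch.allclose(m @ v_g, v_g @ lam_g)

# A = V diag(w) V^{-1} additionally requires V to be invertible,
# i.e. the matrix must be diagonalizable.
inverse_ok = torch.allclose(v_g @ lam_g @ torch.linalg.inv(v_g), m)

print(herm_ok, eigen_equation_ok, inverse_ok)
```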

Attributes:

  • eigenvalues (Tensor): Eigenvalues tensor. Its dims keep the leading batch dimensions and replace the matrix axes with a single IndexSpace labeling the spectrum.
  • eigenvectors (Tensor): Eigenvectors tensor. Its dims keep the leading batch dimensions, followed by the matrix row space and the spectral IndexSpace.

QR

Bases: NamedTuple

QR decomposition result container.

Reconstruction

Reconstruct the original matrix as \(Q R\) on the last two axes. In code, this is Q @ R.

In conventional notation, \(A = Q R\) and \(Q^\dagger Q = I\).
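The two identities above can be verified directly with torch.linalg.qr. This sketch works on a plain real-valued torch tensor, so \(Q^\dagger\) reduces to a transpose; it does not exercise the QTen dims.

```python
import torch

torch.manual_seed(0)
a = torch.randn(5, 3, dtype=torch.float64)  # a tall 5x3 matrix

q, r = torch.linalg.qr(a, mode="reduced")

# A = Q R
reconstructs = torch.allclose(q @ r, a)

# Q^dagger Q = I on the shared factor axis (size 3 here)
orthonormal = torch.allclose(q.transpose(-2, -1) @ q, torch.eye(3, dtype=a.dtype))

# R is upper triangular
upper_triangular = torch.allclose(r, torch.triu(r))

print(reconstructs, orthonormal, upper_triangular)
```

Note that for a tall matrix only \(Q^\dagger Q = I\) holds; \(Q Q^\dagger\) is a projector onto the column space rather than the identity.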

Attributes:

  • Q (Tensor): Orthogonal/unitary factor with dims equal to the leading batch dimensions followed by (row_dim, factor).
  • R (Tensor): Upper-triangular factor with dims equal to the leading batch dimensions followed by (factor, col_dim).

SVD

Bases: NamedTuple

Singular-value decomposition result container.

Reconstruction

Reconstruct the original matrix as \(U\Sigma V^\dagger\). In code, this is U @ Sigma @ Vh, where Sigma is either result.S itself when values_as_matrix=True, or the diagonal matrix formed from the singular values when result.S is returned as a vector.

In conventional notation, \(A = U \Sigma V^\dagger\).
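As a sketch of the vector-valued case, the snippet below rebuilds a matrix from a reduced SVD computed with plain torch.linalg.svd. Building Sigma with torch.diag_embed is one option; scaling the columns of U by the singular values gives the same product without materializing the diagonal matrix.

```python
import torch

torch.manual_seed(0)
a = torch.randn(4, 6, dtype=torch.float64)

# Reduced SVD: u is (4, 4), s is (4,), vh is (4, 6).
u, s, vh = torch.linalg.svd(a, full_matrices=False)

# Option 1: form the diagonal matrix explicitly.
sigma = torch.diag_embed(s)
matrix_form_ok = torch.allclose(u @ sigma @ vh, a)

# Option 2: scale the columns of U by the singular values instead.
vector_form_ok = torch.allclose((u * s.unsqueeze(-2)) @ vh, a)

print(matrix_form_ok, vector_form_ok)
```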

Attributes:

  • U (Tensor): Left singular vectors with dims determined by full_matrices.
  • S (Tensor): Singular values, either as a vector or diagonal matrix depending on values_as_matrix.
  • Vh (Tensor): Right singular vectors in conjugate-transposed form.

eigh

eigh(tensor: Tensor) -> EigH

Perform Hermitian eigendecomposition on the last two tensor dimensions.

This function applies torch.linalg.eigh to the matrix axes of a Tensor. The final two dimensions must span the same Hilbert space up to ray ordering so they can be interpreted as a Hermitian operator. Any leading dimensions are treated as batch dimensions and are preserved in both outputs.

Parameters:

  • tensor (Tensor, required): Input tensor whose last two dimensions form Hermitian matrices.

Returns:

EigH containing:

  • eigenvalues, whose dtype is the real dtype associated with the input and whose dims replace the matrix axes with one IndexSpace.
  • eigenvectors, whose dtype matches the input dtype and whose dims are the leading batch dims followed by (row_dim, spectrum).

Examples:

result = eigh(tensor)
eigenvalues = result.eigenvalues
eigenvectors = result.eigenvectors

Notes

torch.linalg.eigh is differentiable for Hermitian inputs, but the gradients can be ill-defined or unstable when eigenvalues are degenerate or nearly degenerate. If you use this in autograd, consider stabilizing the spectrum (e.g., with a small perturbation) or avoiding backpropagation through eigenvectors when bands are expected to merge.

The original matrix is recovered by forming a diagonal matrix from eigenvalues and evaluating \(V\Lambda V^\dagger\). In code, this is eigenvectors @ W @ eigenvectors.h(-2, -1).
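One way to follow the advice about avoiding backpropagation through eigenvectors is to differentiate a quantity that depends on the eigenvalues only. The sketch below uses torch.linalg.eigvalsh on a plain torch tensor (not a QTen Tensor) and the fact that the sum of the eigenvalues equals the trace, so the expected gradient is the identity matrix.

```python
import torch

# A small symmetric matrix with well-separated eigenvalues.
a = torch.tensor(
    [[2.0, 1.0], [1.0, 3.0]],
    dtype=torch.float64,
    requires_grad=True,
)

# Only eigenvalues enter the loss, so no eigenvector gradient is needed.
eigenvalues = torch.linalg.eigvalsh(a)
loss = eigenvalues.sum()  # equals trace(a)
loss.backward()

# d trace(a) / d a is the identity matrix.
print(a.grad)
```

Losses that depend on the eigenvectors themselves remain sensitive to small eigenvalue gaps, which is the instability the note describes.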

Source code in src/qten/linalg/decompose.py
def eigh(tensor: Tensor) -> EigH:
    r"""
    Perform Hermitian eigendecomposition on the last two tensor dimensions.

    This function applies [`torch.linalg.eigh`](https://pytorch.org/docs/stable/generated/torch.linalg.eigh.html)
    to the matrix axes of a [`Tensor`][qten.linalg.tensors.Tensor]. The final
    two dimensions must span the same Hilbert space up to ray ordering so they
    can be interpreted as a Hermitian operator. Any leading dimensions are
    treated as batch dimensions and are preserved in both outputs.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form Hermitian matrices.

    Returns
    -------
    EigH
        [`EigH`][qten.linalg.decompose.EigH] containing:
        - `eigenvalues`, whose dtype is the real dtype associated with the
          input and whose dims replace the matrix axes with one
          [`IndexSpace`][qten.symbolics.state_space.IndexSpace].
        - `eigenvectors`, whose dtype matches the input dtype and whose dims
          are the leading batch dims followed by `(row_dim, spectrum)`.

    Examples
    --------
    ```python
    result = eigh(tensor)
    eigenvalues = result.eigenvalues
    eigenvectors = result.eigenvectors
    ```

    Notes
    -----
    `torch.linalg.eigh` is differentiable for Hermitian inputs, but the gradients
    can be ill-defined or unstable when eigenvalues are degenerate or nearly
    degenerate. If you use this in autograd, consider stabilizing the spectrum
    (e.g., with a small perturbation) or avoiding backpropagation through
    eigenvectors when bands are expected to merge.

    The original matrix is recovered by forming a diagonal matrix from
    `eigenvalues` and evaluating \(V\Lambda V^\dagger\). In code, this is
    `eigenvectors @ W @ eigenvectors.h(-2, -1)`.
    """
    _assert_eig_dims(tensor)

    dim0 = tensor.dims[-2]
    target = tensor.align(-1, dim0)  # Align column space to match the row space
    eigenvalues, eigenvectors = torch.linalg.eigh(target.data)

    spectrum = IndexSpace.linear(eigenvalues.shape[-1])

    eigvals = Tensor(
        data=eigenvalues,
        dims=target.dims[:-2] + (spectrum,),
    )
    eigvecs = Tensor(
        data=eigenvectors,
        dims=target.dims[:-2] + (dim0, spectrum),
    )

    return EigH(eigvals, eigvecs)

eigvalsh

eigvalsh(tensor: Tensor) -> Tensor

Compute Hermitian eigenvalues on the last two tensor dimensions.

This is the eigenvalues-only companion to eigh. The last two dimensions must span the same Hilbert space up to ray ordering and represent a Hermitian operator. Leading dimensions are treated as batch dimensions.

Parameters:

  • tensor (Tensor, required): Input tensor whose last two dimensions form Hermitian matrices.

Returns:

  • Tensor: Eigenvalues as a Tensor whose dtype matches the real dtype associated with the input and whose dims keep the leading batch dimensions while replacing the matrix axes with a single IndexSpace.
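The "real dtype associated with the input" can be seen with the underlying PyTorch call. In this sketch a complex128 Hermitian matrix, held as a plain torch tensor, yields float64 eigenvalues.

```python
import torch

# Hermitian 2x2 matrix with complex off-diagonal entries.
a = torch.tensor([[2.0, 1j], [-1j, 2.0]], dtype=torch.complex128)

values = torch.linalg.eigvalsh(a)

# The input is complex128, but Hermitian eigenvalues are real: float64.
print(values.dtype)

# The characteristic polynomial is (2 - x)^2 - 1, so the eigenvalues are
# approximately 1 and 3, returned in ascending order.
print(values)
```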

Examples:

values = eigvalsh(tensor)

Source code in src/qten/linalg/decompose.py
def eigvalsh(tensor: Tensor) -> Tensor:
    """
    Compute Hermitian eigenvalues on the last two tensor dimensions.

    This is the eigenvalues-only companion to
    [`eigh`][qten.linalg.decompose.eigh]. The last two dimensions must span
    the same Hilbert space up to ray ordering and represent a Hermitian
    operator. Leading dimensions are treated as batch dimensions.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form Hermitian matrices.

    Returns
    -------
    Tensor
        Eigenvalues as a [`Tensor`][qten.linalg.tensors.Tensor] whose dtype
        matches the real dtype associated with the input and whose dims keep
        the leading batch dimensions while replacing the matrix axes with a
        single [`IndexSpace`][qten.symbolics.state_space.IndexSpace].

    Examples
    --------
    ```python
    values = eigvalsh(tensor)
    ```
    """
    _assert_eig_dims(tensor)

    dim0 = tensor.dims[-2]
    target = tensor.align(-1, dim0)  # Align column space to match the row space
    eigenvalues = torch.linalg.eigvalsh(target.data)

    spectrum = IndexSpace.linear(eigenvalues.shape[-1])

    vals = Tensor(
        data=eigenvalues,
        dims=target.dims[:-2] + (spectrum,),
    )

    return vals

eig

eig(tensor: Tensor) -> EigH

Perform eigendecomposition on general square matrix axes.

This function applies torch.linalg.eig to the final two dimensions of a Tensor. The last two dimensions must span the same Hilbert space up to ray ordering so they can be interpreted as a square operator. Any leading dimensions are treated as batch dimensions and are preserved in both outputs.

Parameters:

  • tensor (Tensor, required): Input tensor whose last two dimensions form square matrices.

Returns:

EigH containing:

  • eigenvalues, whose dtype is the complex dtype associated with the input and whose dims replace the matrix axes with one IndexSpace.
  • eigenvectors, whose dtype matches that complex dtype and whose dims are the leading batch dims followed by (row_dim, spectrum).

Examples:

result = eig(tensor)
values = result.eigenvalues
vectors = result.eigenvectors

Notes

torch.linalg.eig does not guarantee any ordering of the eigenvalues. This function sorts eigenvalues lexicographically by (real, imag) and applies the same reordering to eigenvectors.

The returned tensors satisfy \(A V = V\Lambda\), where \(\Lambda\) is the diagonal matrix of eigenvalues. If the input matrix is diagonalizable, this gives the reconstruction \(A = V\Lambda V^{-1}\). In code, \(V\) is eigenvectors and \(\Lambda\) is the diagonal matrix built from eigenvalues.
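The private helper that performs the sorting is not shown on this page, so the following is only a sketch of what a lexicographic (real, imag) ordering can look like for a single unbatched eigensystem. The function name sort_eigenpairs_1d is hypothetical, and the approach (two stable sorts, secondary key first) is an assumption rather than the library's implementation.

```python
import torch


def sort_eigenpairs_1d(values, vectors):
    """Hypothetical helper: order one eigensystem by (real, imag)."""
    # Two stable sorts produce a lexicographic order:
    # sort by the secondary key (imag) first, then by the primary key (real).
    order = torch.sort(values.imag, stable=True).indices
    order = order[torch.sort(values.real[order], stable=True).indices]
    # Eigenvectors are stored as columns, so permute the last axis.
    return values[order], vectors[:, order]


values = torch.tensor([1 + 1j, 0 + 2j, 1 - 1j, 0 - 2j], dtype=torch.complex128)
# Column k of the identity tags the original position of eigenvalue k.
vectors = torch.eye(4, dtype=torch.complex128)

sorted_values, sorted_vectors = sort_eigenpairs_1d(values, vectors)
print(sorted_values)
```

Applying the same permutation to the eigenvector columns keeps each eigenvalue paired with its eigenvector, which is what preserves \(A V = V\Lambda\) after sorting.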

Source code in src/qten/linalg/decompose.py
def eig(tensor: Tensor) -> EigH:
    r"""
    Perform eigendecomposition on general square matrix axes.

    This function applies [`torch.linalg.eig`](https://pytorch.org/docs/stable/generated/torch.linalg.eig.html)
    to the final two dimensions of a
    [`Tensor`][qten.linalg.tensors.Tensor]. The last two dimensions must span
    the same Hilbert space up to ray ordering so they can be interpreted as a
    square operator. Any leading dimensions are treated as batch dimensions and
    are preserved in both outputs.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form square matrices.

    Returns
    -------
    EigH
        [`EigH`][qten.linalg.decompose.EigH] containing:
        - `eigenvalues`, whose dtype is the complex dtype associated with the
          input and whose dims replace the matrix axes with one
          [`IndexSpace`][qten.symbolics.state_space.IndexSpace].
        - `eigenvectors`, whose dtype matches that complex dtype and whose dims
          are the leading batch dims followed by `(row_dim, spectrum)`.

    Examples
    --------
    ```python
    result = eig(tensor)
    values = result.eigenvalues
    vectors = result.eigenvectors
    ```

    Notes
    -----
    `torch.linalg.eig` does not guarantee any ordering of the eigenvalues. This
    function sorts eigenvalues lexicographically by `(real, imag)` and applies
    the same reordering to eigenvectors.

    The returned tensors satisfy \(A V = V\Lambda\), where \(\Lambda\) is the
    diagonal matrix of eigenvalues. If the input matrix is diagonalizable, this
    gives the reconstruction \(A = V\Lambda V^{-1}\). In code, \(V\) is
    `eigenvectors` and \(\Lambda\) is the diagonal matrix built from
    `eigenvalues`.
    """
    _assert_eig_dims(tensor)

    dim0 = tensor.dims[-2]
    target = tensor.align(-1, dim0)  # Align column space to match the row space
    eigenvalues, eigenvectors = torch.linalg.eig(target.data)
    eigenvalues, eigenvectors = _sort_eigenpairs(eigenvalues, eigenvectors)

    spectrum = IndexSpace.linear(eigenvalues.shape[-1])

    eigvals = Tensor(
        data=eigenvalues,
        dims=target.dims[:-2] + (spectrum,),
    )
    eigvecs = Tensor(
        data=eigenvectors,
        dims=target.dims[:-2] + (dim0, spectrum),
    )

    return EigH(eigvals, eigvecs)

eigvals

eigvals(tensor: Tensor) -> Tensor

Compute eigenvalues of general square matrix axes.

This is the eigenvalues-only companion to eig. The last two dimensions must span the same Hilbert space up to ray ordering and represent a square operator. Leading dimensions are treated as batch dimensions.

Parameters:

  • tensor (Tensor, required): Input tensor whose last two dimensions form square matrices.

Returns:

  • Tensor: Eigenvalues as a Tensor whose dtype matches the complex dtype associated with the input and whose dims keep the leading batch dimensions while replacing the matrix axes with a single IndexSpace.
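The complex output dtype matters even for real inputs. As a sketch with the underlying torch.linalg.eigvals call on a plain torch tensor, a real 90-degree rotation matrix has the purely imaginary eigenvalues \(\pm i\).

```python
import torch

# Real rotation by 90 degrees: no real eigenvalues exist.
rotation = torch.tensor([[0.0, -1.0], [1.0, 0.0]], dtype=torch.float64)

values = torch.linalg.eigvals(rotation)

# float64 input produces complex128 eigenvalues.
print(values.dtype)

# Sort the imaginary parts for a deterministic comparison, since the raw
# torch routine does not promise an ordering.
imag_sorted = torch.sort(values.imag).values
print(imag_sorted)
```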

Notes

torch.linalg.eigvals does not guarantee any ordering of the eigenvalues. This function sorts eigenvalues lexicographically by (real, imag).

Examples:

values = eigvals(tensor)

Source code in src/qten/linalg/decompose.py
def eigvals(tensor: Tensor) -> Tensor:
    """
    Compute eigenvalues of general square matrix axes.

    This is the eigenvalues-only companion to
    [`eig`][qten.linalg.decompose.eig]. The last two dimensions must span the
    same Hilbert space up to ray ordering and represent a square operator.
    Leading dimensions are treated as batch dimensions.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form square matrices.

    Returns
    -------
    Tensor
        Eigenvalues as a [`Tensor`][qten.linalg.tensors.Tensor] whose dtype
        matches the complex dtype associated with the input and whose dims keep
        the leading batch dimensions while replacing the matrix axes with a
        single [`IndexSpace`][qten.symbolics.state_space.IndexSpace].

    Notes
    -----
    `torch.linalg.eigvals` does not guarantee any ordering of the eigenvalues.
    This function sorts eigenvalues lexicographically by `(real, imag)`.

    Examples
    --------
    ```python
    values = eigvals(tensor)
    ```
    """
    _assert_eig_dims(tensor)

    dim0 = tensor.dims[-2]
    target = tensor.align(-1, dim0)  # Align column space to match the row space
    eigenvalues = torch.linalg.eigvals(target.data)
    eigenvalues, _ = _sort_eigenpairs(eigenvalues)

    spectrum = IndexSpace.linear(eigenvalues.shape[-1])

    vals = Tensor(
        data=eigenvalues,
        dims=target.dims[:-2] + (spectrum,),
    )

    return vals

qr

qr(tensor: Tensor) -> QR

Perform reduced QR decomposition on the last two tensor dimensions.

This function applies torch.linalg.qr with mode="reduced" to the matrix axes of the input tensor. The last two dimensions may be rectangular. Any leading dimensions are treated as batch dimensions and are preserved in both outputs.

Parameters:

  • tensor (Tensor, required): Input tensor whose last two dimensions form matrices.

Returns:

QR containing:

  • Q, a Tensor with orthonormal columns and dims (..., row_dim, factor).
  • R, an upper-triangular Tensor with dims (..., factor, col_dim).

Examples:

result = qr(tensor)
q = result.Q
r = result.R

Notes

The shared factor axis is represented by an IndexSpace whose size equals the reduced QR bond dimension. The original matrix is recovered as \(Q R\), via Q @ R in code.
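For torch.linalg.qr with mode="reduced", the size of the shared axis is the smaller of the row and column counts. The sketch below shows this for batched tall and wide inputs, using plain torch tensors in place of QTen Tensor objects.

```python
import torch

torch.manual_seed(0)

# Batches of two matrices each: one tall (5x3), one wide (3x5).
tall = torch.randn(2, 5, 3, dtype=torch.float64)
wide = torch.randn(2, 3, 5, dtype=torch.float64)

q_tall, r_tall = torch.linalg.qr(tall, mode="reduced")
q_wide, r_wide = torch.linalg.qr(wide, mode="reduced")

# The shared axis has size min(rows, cols) = 3 in both cases.
print(q_tall.shape, r_tall.shape)
print(q_wide.shape, r_wide.shape)
```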

Source code in src/qten/linalg/decompose.py
def qr(tensor: Tensor) -> QR:
    r"""
    Perform reduced QR decomposition on the last two tensor dimensions.

    This function applies [`torch.linalg.qr`](https://pytorch.org/docs/stable/generated/torch.linalg.qr.html)
    with `mode="reduced"` to the matrix axes of the input tensor. The last two
    dimensions may be rectangular. Any leading dimensions are treated as batch
    dimensions and are preserved in both outputs.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form matrices.

    Returns
    -------
    QR
        [`QR`][qten.linalg.decompose.QR] containing:
        - `Q`, a [`Tensor`][qten.linalg.tensors.Tensor] with orthonormal
          columns and dims `(..., row_dim, factor)`.
        - `R`, an upper-triangular
          [`Tensor`][qten.linalg.tensors.Tensor] with dims
          `(..., factor, col_dim)`.

    Examples
    --------
    ```python
    result = qr(tensor)
    q = result.Q
    r = result.R
    ```

    Notes
    -----
    The shared `factor` axis is represented by an
    [`IndexSpace`][qten.symbolics.state_space.IndexSpace] whose size equals the
    reduced QR bond dimension. The original matrix is recovered as \(Q R\), via
    `Q @ R` in code.
    """
    if tensor.rank() < 2:
        raise ValueError(
            "Input tensor must have at least two dimensions for matrix decomposition."
        )

    row_dim = tensor.dims[-2]
    col_dim = tensor.dims[-1]

    q_data, r_data = torch.linalg.qr(tensor.data, mode="reduced")
    spectral_dim = IndexSpace.linear(q_data.shape[-1])

    q = Tensor(
        data=q_data,
        dims=tensor.dims[:-2] + (row_dim, spectral_dim),
    )
    r = Tensor(
        data=r_data,
        dims=tensor.dims[:-2] + (spectral_dim, col_dim),
    )

    return QR(q, r)

svd

svd(
    tensor: Tensor,
    values_as_matrix: bool = False,
    full_matrices: bool = False,
) -> SVD

Perform singular value decomposition on the last two tensor dimensions.

This function applies torch.linalg.svd to the matrix axes of the input tensor and returns symbolic dimensions that distinguish reduced and full factorizations. The last two dimensions may be rectangular. Any leading dimensions are treated as batch dimensions and are preserved in all outputs.

Parameters:

  • tensor (Tensor, required): Input tensor whose last two dimensions form matrices.
  • values_as_matrix (bool, default False): If True, return singular values as an explicit diagonal matrix tensor. If False, return them as a vector on a single spectral axis.
  • full_matrices (bool, default False): If True, compute full-sized U and Vh. If False, compute the reduced SVD.

Returns:

SVD containing:

  • U, with dims (..., row_dim, factor) for reduced SVD or (..., row_dim, left_factor) for full SVD.
  • S, with dims (..., factor) when values_as_matrix=False (the default), (..., factor, factor) when values_as_matrix=True in reduced mode, or (..., left_factor, right_factor) when both values_as_matrix=True and full_matrices=True.
  • Vh, with dims (..., factor, col_dim) for reduced SVD or (..., right_factor, col_dim) for full SVD.

Examples:

result = svd(tensor)
u = result.U
s = result.S
vh = result.Vh

Notes

In reduced mode, factor is the shared singular-value IndexSpace. In full mode, left_factor and right_factor are sized to the full row and column spaces of the input matrix axes. The original matrix is recovered as \(U\Sigma V^\dagger\), using U @ Sigma @ Vh in code. Here Sigma is either the returned S tensor (values_as_matrix=True) or the diagonal matrix formed from the returned singular-value vector.
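In full mode the diagonal matrix is rectangular, so the singular values have to be embedded into a zero matrix of the input's shape before multiplying. The sketch below mirrors that construction with plain torch tensors; it is an illustration and does not call the QTen svd function itself.

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, 5, dtype=torch.float64)

# Full SVD: u is (3, 3), s is (3,), vh is (5, 5).
u, s, vh = torch.linalg.svd(a, full_matrices=True)

rows, cols = a.shape[-2:]
k = s.shape[-1]  # number of singular values, min(rows, cols)

# Embed the k singular values into the top-left block of a rows x cols
# zero matrix so that the product U @ Sigma @ Vh is well-defined.
sigma = torch.zeros(rows, cols, dtype=s.dtype)
sigma[:k, :k] = torch.diag_embed(s)

full_ok = torch.allclose(u @ sigma @ vh, a)
print(u.shape, sigma.shape, vh.shape, full_ok)
```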

Source code in src/qten/linalg/decompose.py
def svd(
    tensor: Tensor,
    values_as_matrix: bool = False,
    full_matrices: bool = False,
) -> SVD:
    r"""
    Perform singular value decomposition on the last two tensor dimensions.

    This function applies [`torch.linalg.svd`](https://pytorch.org/docs/stable/generated/torch.linalg.svd.html)
    to the matrix axes of the input tensor and returns symbolic dimensions that
    distinguish reduced and full factorizations. The last two dimensions may be
    rectangular. Any leading dimensions are treated as batch dimensions and are
    preserved in all outputs.

    Parameters
    ----------
    tensor : Tensor
        Input tensor whose last two dimensions form matrices.
    values_as_matrix : bool, default `False`
        If `True`, return singular values as an explicit diagonal matrix
        tensor. If `False`, return them as a vector on a single spectral axis.
    full_matrices : bool, default `False`
        If `True`, compute full-sized `U` and `Vh`. If `False`, compute the
        reduced SVD.

    Returns
    -------
    SVD
        [`SVD`][qten.linalg.decompose.SVD] containing:
        - `U`, with dims `(..., row_dim, factor)` for reduced SVD or
          `(..., row_dim, left_factor)` for full SVD.
        - `S`, with dims `(..., factor)` by default, `(..., factor, factor)`
          when `values_as_matrix=True` in reduced mode, or
          `(..., left_factor, right_factor)` when both `values_as_matrix=True`
          and `full_matrices=True`.
        - `Vh`, with dims `(..., factor, col_dim)` for reduced SVD or
          `(..., right_factor, col_dim)` for full SVD.

    Examples
    --------
    ```python
    result = svd(tensor)
    u = result.U
    s = result.S
    vh = result.Vh
    ```

    Notes
    -----
    In reduced mode, `factor` is the shared singular-value
    [`IndexSpace`][qten.symbolics.state_space.IndexSpace]. In full mode,
    `left_factor` and `right_factor` are sized to the full row and column
    spaces of the input matrix axes. The original matrix is recovered as
    \(U\Sigma V^\dagger\), using `U @ Sigma @ Vh` in code. Here `Sigma` is
    either the returned `S` tensor (`values_as_matrix=True`) or the diagonal
    matrix formed from the returned singular-value vector.
    """
    if tensor.rank() < 2:
        raise ValueError(
            "Input tensor must have at least two dimensions for matrix decomposition."
        )

    row_dim = tensor.dims[-2]
    col_dim = tensor.dims[-1]

    u_data, s_data, vh_data = torch.linalg.svd(tensor.data, full_matrices=full_matrices)

    factor = IndexSpace.linear(s_data.shape[-1])

    if full_matrices:
        left_factor = IndexSpace.linear(row_dim.dim)
        right_factor = IndexSpace.linear(col_dim.dim)
        u = Tensor(
            data=u_data,
            dims=tensor.dims[:-2] + (row_dim, left_factor),
        )
    else:
        u = Tensor(
            data=u_data,
            dims=tensor.dims[:-2] + (row_dim, factor),
        )
    if values_as_matrix:
        if full_matrices:
            k = s_data.shape[-1]
            s_mat = torch.zeros(
                *s_data.shape[:-1],
                left_factor.dim,
                right_factor.dim,
                dtype=s_data.dtype,
                device=s_data.device,
            )
            diag = torch.diag_embed(s_data)
            s_mat[..., :k, :k] = diag
            s = Tensor(
                data=s_mat,
                dims=tensor.dims[:-2] + (left_factor, right_factor),
            )
        else:
            s_mat = torch.diag_embed(s_data)
            s = Tensor(
                data=s_mat,
                dims=tensor.dims[:-2] + (factor, factor),
            )
    else:
        s = Tensor(
            data=s_data,
            dims=tensor.dims[:-2] + (factor,),
        )
    if full_matrices:
        vh = Tensor(
            data=vh_data,
            dims=tensor.dims[:-2] + (right_factor, col_dim),
        )
    else:
        vh = Tensor(
            data=vh_data,
            dims=tensor.dims[:-2] + (factor, col_dim),
        )

    return SVD(u, s, vh)