
qten.bands

Module reference for qten.bands.

bands

Band-structure helpers for momentum-resolved QTen tensors.

This module provides utilities for transforming, folding, unfolding, filling, and selecting bands represented as Tensor objects. The common convention is that a band tensor has dimensions (MomentumSpace, HilbertSpace, HilbertSpace): the MomentumSpace axis indexes crystal momenta and the two HilbertSpace axes form the Hamiltonian or operator matrix at each momentum.

Mathematical convention

A band tensor represents a family of matrices indexed by crystal momentum: \(H : k \mapsto H(k)\), with \(H(k)_{ab} = \langle a | H(k) | b \rangle\).

In code, this is stored as a rank-3 Tensor with dims (K, B_left, B_right), where K is a MomentumSpace and the two Hilbert-space axes supply the row and column basis labels of each matrix block.

Geometry transformations act on both parts of this object, \(k \mapsto k'\) and \(H(k) \mapsto U(k)\,H(k)\,U(k)^\dagger\), where the \(k\)-dependent change-of-basis matrix \(U(k)\) is assembled from symbolic Hilbert-space relabeling and finite Fourier transforms.
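As a plain-NumPy illustration of this convention (a stand-in for QTen's `Tensor`, not its API), the sketch below stacks the blocks of a hypothetical two-orbital chain into a `(K, B, B)` array and applies a per-momentum unitary conjugation with batched matrix products. The model parameters `v`, `w` and the diagonal phase unitary are assumptions chosen only for the example.

```python
import numpy as np


def two_band_blocks(ks, v=1.0, w=0.5):
    """Stack H(k) for a hypothetical two-orbital chain into shape (K, 2, 2)."""
    off_diagonal = v + w * np.exp(-1j * ks)
    h = np.zeros((len(ks), 2, 2), dtype=complex)
    h[:, 0, 1] = off_diagonal
    h[:, 1, 0] = np.conj(off_diagonal)
    return h


def conjugate_blocks(h, u):
    """Apply H(k) -> U(k) H(k) U(k)^dagger to every momentum block at once."""
    u_dagger = np.conj(np.swapaxes(u, -1, -2))
    return u @ h @ u_dagger


ks = 2 * np.pi * np.arange(8) / 8      # K axis: 8 crystal momenta
h = two_band_blocks(ks)                 # (K, B_left, B_right)
# A k-dependent diagonal phase is one simple example of a unitary U(k).
u = np.zeros_like(h)
u[:, 0, 0] = 1.0
u[:, 1, 1] = np.exp(1j * ks)
h_new = conjugate_blocks(h, u)
```

Because NumPy's `@` broadcasts over the leading axis, each momentum sector is conjugated independently, and a unitary conjugation leaves the spectrum of every block unchanged.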

Repository usage

The functions here sit between geometry, symbolic Hilbert-space labels, and linear algebra. Geometry objects provide real and reciprocal lattice structure, symbolic state spaces label tensor axes, and linear algebra routines diagonalize the momentum-sector matrices when filling or selecting bands.

interpolate_path

interpolate_path(
    recip: ReciprocalLattice,
    waypoints: Sequence[Union[Tuple[float, ...], str]],
    n_points: int = 100,
    labels: Optional[Sequence[str]] = None,
    points: Optional[Dict[str, Tuple[float, ...]]] = None,
) -> BzPath

Build a sampled Brillouin-zone path in a reciprocal lattice.

This is a backward-compatible wrapper around interpolate_reciprocal_path. New code may call that symbolic helper directly.

Parameters:

recip (ReciprocalLattice, required): Reciprocal lattice in which waypoint coordinates are interpreted.

waypoints (Sequence[Union[Tuple[float, ...], str]], required): Sequence of explicit fractional coordinates or names looked up in points. For example, [(0.0, 0.0), (0.5, 0.0), (0.5, 0.5)] samples a path through three explicit two-dimensional reciprocal coordinates, while ["G", "X", "M"] resolves coordinates from the points mapping.

n_points (int, default 100): Number of samples used along the full interpolated path.

labels (Sequence[str] | None, default None): Optional display labels for the waypoint ticks. For example, ["Γ", "X", "M"] can label a path whose named inputs are ["G", "X", "M"].

points (Dict[str, Tuple[float, ...]] | None, default None): Optional mapping from waypoint names to fractional reciprocal coordinates. For example, {"G": (0.0, 0.0), "X": (0.5, 0.0), "M": (0.5, 0.5)}.

Returns:

BzPath: Sampled Brillouin-zone path with momentum space, waypoint labels, and path-order metadata.

Raises:

ValueError: If fewer than two waypoints are supplied, if a named waypoint is not present in points, if waypoint coordinate dimensions do not match recip.dim, if n_points is too small for the number of waypoints, if all waypoints are identical, or if labels does not match the number of waypoints.

See Also

interpolate_reciprocal_path(recip, waypoints, n_points, labels, points): canonical implementation used by this compatibility wrapper.

Examples:

path = interpolate_path(
    recip,
    waypoints=[(0.0, 0.0), (0.5, 0.0), (0.5, 0.5)],
    labels=["Γ", "X", "M"],
)

path = interpolate_path(
    recip,
    waypoints=["G", "X", "M"],
    labels=["Γ", "X", "M"],
    points={"G": (0.0, 0.0), "X": (0.5, 0.0), "M": (0.5, 0.5)},
)
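The wrapper only forwards its arguments, so the sampling itself happens in interpolate_reciprocal_path, whose algorithm is not shown on this page. As a rough mental model only, the hypothetical `sample_path` below spaces `n_points` samples uniformly in Cartesian arc length along a piecewise-linear path through fractional waypoints. The arc-length spacing, the "rows are basis vectors" convention for `basis`, and the `n_points` threshold are assumptions, not QTen's documented behaviour.

```python
import numpy as np


def sample_path(waypoints, n_points, basis):
    """Hypothetical piecewise-linear Brillouin-zone path sampler.

    waypoints: (W, d) fractional reciprocal coordinates.
    basis: (d, d) array whose rows are the reciprocal basis vectors (assumed).
    Returns (n_points, d) fractional samples and the arc length of each waypoint.
    """
    frac = np.asarray(waypoints, dtype=float)
    if len(frac) < 2:
        raise ValueError("need at least two waypoints")
    if n_points < len(frac):  # simple stand-in for "too small"
        raise ValueError("n_points is too small for the number of waypoints")
    cart = frac @ np.asarray(basis, dtype=float)
    seg_lengths = np.linalg.norm(np.diff(cart, axis=0), axis=1)
    waypoint_s = np.concatenate([[0.0], np.cumsum(seg_lengths)])
    if waypoint_s[-1] == 0.0:
        raise ValueError("all waypoints are identical")
    # Uniform samples in arc length, then linear interpolation per coordinate.
    s = np.linspace(0.0, waypoint_s[-1], n_points)
    samples = np.column_stack(
        [np.interp(s, waypoint_s, frac[:, axis]) for axis in range(frac.shape[1])]
    )
    return samples, waypoint_s


square_basis = 2 * np.pi * np.eye(2)
samples, ticks = sample_path([(0.0, 0.0), (0.5, 0.0), (0.5, 0.5)], 5, square_basis)
```

On a square lattice the Γ to X and X to M legs have equal length, so the five samples split evenly between them.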
Source code in src/qten/bands.py
def interpolate_path(
    recip: ReciprocalLattice,
    waypoints: Sequence[Union[Tuple[float, ...], str]],
    n_points: int = 100,
    labels: Optional[Sequence[str]] = None,
    points: Optional[Dict[str, Tuple[float, ...]]] = None,
) -> BzPath:
    """
    Build a sampled Brillouin-zone path in a reciprocal lattice.

    This is a backward-compatible wrapper around
    [`interpolate_reciprocal_path`][qten.symbolics.ops.interpolate_reciprocal_path].
    New code may call that symbolic helper directly.

    Parameters
    ----------
    recip : ReciprocalLattice
        Reciprocal lattice in which waypoint coordinates are interpreted.
    waypoints : Sequence[Union[Tuple[float, ...], str]]
        Sequence of explicit fractional coordinates or names looked up in
        `points`.
        For example, `[(0.0, 0.0), (0.5, 0.0), (0.5, 0.5)]`
        samples a path through three explicit two-dimensional reciprocal
        coordinates, while `["G", "X", "M"]` resolves coordinates from the
        `points` mapping.
    n_points : int
        Number of samples used along the full interpolated path.
    labels : Sequence[str] | None
        Optional display labels for the waypoint ticks.
        For example, `["Γ", "X", "M"]` can label a path whose named inputs are
        `["G", "X", "M"]`.
    points : Dict[str, Tuple[float, ...]] | None
        Optional mapping from waypoint names to fractional reciprocal
        coordinates. For example,
        `{"G": (0.0, 0.0), "X": (0.5, 0.0), "M": (0.5, 0.5)}`.

    Returns
    -------
    BzPath
        Sampled Brillouin-zone path with momentum space, waypoint labels, and
        path-order metadata.

    Raises
    ------
    ValueError
        If fewer than two waypoints are supplied, if a named waypoint is not
        present in `points`, if waypoint coordinate dimensions do not match
        `recip.dim`, if `n_points` is too small for the number of waypoints, if
        all waypoints are identical, or if `labels` does not match the number
        of waypoints.

    See Also
    --------
    [`interpolate_reciprocal_path(recip, waypoints, n_points, labels, points)`][qten.symbolics.ops.interpolate_reciprocal_path]
        Canonical implementation used by this compatibility wrapper.

    Examples
    --------
    ```python
    path = interpolate_path(
        recip,
        waypoints=[(0.0, 0.0), (0.5, 0.0), (0.5, 0.5)],
        labels=["Γ", "X", "M"],
    )
    ```

    ```python
    path = interpolate_path(
        recip,
        waypoints=["G", "X", "M"],
        labels=["Γ", "X", "M"],
        points={"G": (0.0, 0.0), "X": (0.5, 0.0), "M": (0.5, 0.5)},
    )
    ```
    """
    return interpolate_reciprocal_path(
        recip=recip,
        waypoints=waypoints,
        n_points=n_points,
        labels=labels,
        points=points,
    )

bandtransform

bandtransform(t: Opr, tensor: Tensor) -> Tensor

Apply a basis transform to a momentum-resolved operator tensor.

The expected tensor shape is (K, B_left, B_right) where K is a MomentumSpace and B_left, B_right are HilbertSpace axes. This function applies the operator-induced basis transform to both Hilbert-space legs of the band tensor.

For each transformed side, a k-dependent matrix is built from two ingredients: the action of t on the Hilbert-space basis, and the Fourier transforms that connect Bloch and real-space sectors.

Mathematical action

Let \(B\) be the input Hilbert-space basis and \(tB\) the transformed basis. After wrapping transformed sites back to the home unit cell, the finite Fourier transform contributes a momentum-dependent phase. The resulting basis-change matrix is denoted \(U_t(k)\). The transformed band block is \(H'(t k) = U_t(k)\,H(k)\,U_t(k)^\dagger\). In code, left_fourier and right_fourier are the two \(U_t(k)\)-style maps, and the products are left_fourier @ tensor and tensor @ right_fourier.h(-2, -1).
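To make the phase bookkeeping concrete, here is a NumPy sketch under explicitly assumed conventions; it is not QTen's fourier_transform or cross_gram. Bloch blocks are taken to carry only lattice-vector phases, \(H(k)_{ab} = \sum_R h(0a; Rb)\,e^{ikR}\), and the operation is assumed to send home-cell orbital \(a\) to orbital `perm[a]` displaced by the lattice vector `shifts[a]`. Under those conventions \(U_t(k)\) has entries \(U_{\mathrm{perm}[a],a} = e^{-i (tk) R_a}\), and for a model symmetric under \(t\) the conjugated block equals \(H(tk)\). Both examples use inversion \(x \mapsto -x\) on a hypothetical 1D chain.

```python
import numpy as np


def chain_blocks(ks, v=1.0, w=0.4):
    """H(k) of a hypothetical two-orbital chain in the lattice-vector gauge.

    v couples A(R) to B(R); w couples A(R) to B(R - 1).
    """
    h = np.zeros((len(ks), 2, 2), dtype=complex)
    h[:, 0, 1] = v + w * np.exp(-1j * ks)
    h[:, 1, 0] = np.conj(h[:, 0, 1])
    return h


def transform_blocks(h, mapped_ks, perm, shifts):
    """Return U_t(k) H(k) U_t(k)^dagger with U[perm[a], a] = exp(-i (t k) R_a).

    mapped_ks holds t @ k for each input sector; shifts[a] is the lattice
    vector picked up when orbital a is wrapped back into the home cell.
    """
    n_k, n_orb, _ = h.shape
    u = np.zeros((n_k, n_orb, n_orb), dtype=complex)
    for a in range(n_orb):
        u[:, perm[a], a] = np.exp(-1j * mapped_ks * shifts[a])
    return u @ h @ np.conj(np.swapaxes(u, -1, -2))


ks = 2 * np.pi * np.arange(6) / 6
# Orbitals at 0.25 and 0.75: x -> -x swaps them, each wrapped by R = -1.
h_swapped = transform_blocks(chain_blocks(ks), -ks, perm=[1, 0], shifts=[-1, -1])
# Orbitals at 0.0 and 0.5 with v == w: A stays put while B wraps by R = -1,
# so the relative wrap phase is what makes the result match H(-k).
h_fixed = transform_blocks(
    chain_blocks(ks, v=1.0, w=1.0), -ks, perm=[0, 1], shifts=[0, -1]
)
```

In the second case, dropping the wrap phase (all shifts zero) would return \(H(k)\) instead of \(H(-k)\), which is why the Fourier factor is part of \(U_t(k)\).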

Momentum handling
  • The action on Momentum is treated as a relabeling/permutation of sectors.
  • The output tensor carries the transformed momentum axis mapped_kspace = {t @ k | k in kspace}.
  • Each output k-block is populated from the preimage source block before the Hilbert-space conjugation is applied.
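The sector relabeling described in this list can be pictured with the hypothetical helpers below, which are not QTen's internal _momentum_map. Each grid momentum, given in fractional reciprocal coordinates of shape `(K, d)`, is mapped by the operation, matched to a grid point modulo reciprocal lattice vectors, and its block is written to the matched output slot. The coordinate layout and the `1e-9` tolerance are assumptions of the sketch.

```python
import numpy as np


def momentum_permutation(frac_ks, op, tol=1e-9):
    """Return perm with perm[i] = index of op(k_i) in the grid, modulo 1."""
    ks = np.asarray(frac_ks, dtype=float)
    perm = np.empty(len(ks), dtype=int)
    for i, k in enumerate(ks):
        diff = np.asarray(op(k), dtype=float) - ks
        # Two momenta are equivalent if they differ by a reciprocal lattice vector.
        residual = np.abs(diff - np.round(diff)).max(axis=1)
        hits = np.flatnonzero(residual < tol)
        if hits.size != 1:
            raise ValueError(f"momentum {k} does not map onto the grid")
        perm[i] = hits[0]
    if len(set(perm.tolist())) != len(perm):
        raise ValueError("the operation does not permute the momentum grid")
    return perm


def relabel_sectors(h, perm):
    """Move block i of h to output slot perm[i] (the image of k_i)."""
    out = np.empty_like(h)
    out[perm] = h
    return out


frac_ks = np.arange(4).reshape(-1, 1) / 4            # 1D grid: 0, 1/4, 1/2, 3/4
perm = momentum_permutation(frac_ks, lambda k: -k)   # inversion k -> -k
blocks = np.arange(4 * 2 * 2, dtype=float).reshape(4, 2, 2)
relabelled = relabel_sectors(blocks, perm)
```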

Notes

This function accepts a general Opr, but not every Opr is valid here. In practice, t must act coherently across the real-space and momentum-space labels carried by the tensor:

  • t @ k must be defined for each Momentum in the first tensor axis.
  • t @ psi must be defined for each U1Basis in the Hilbert-space axes, in particular for the Offset irrep stored inside each basis state.
  • The Hilbert-space action and momentum action must be dual-compatible, so that the Fourier transform remains consistent after applying t.
  • After applying FuncOpr(Offset, Offset.fractional), the transformed Hilbert space must have the same rays as the original one; otherwise the transformed basis does not close on the input band space and this function raises ValueError.

Operators that only act on abstract U1Basis values or only on Momentum values are not sufficient. The operator must provide matching actions on site offsets and crystal momentum.
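The closure condition in the last bullet can be sketched in plain NumPy. The hypothetical `match_transformed_sites` below applies an operation to fractional site positions, checks that every image coincides with some home-cell site up to an integer lattice vector, and returns the orbital permutation and wrap shifts. It is only an analogue of the same_rays check; the positions, the operation, and the tolerance are assumptions.

```python
import numpy as np


def match_transformed_sites(frac_positions, op, tol=1e-9):
    """Check that op maps the home-cell sites onto themselves up to lattice vectors.

    frac_positions: (n, d) fractional site offsets inside the home cell.
    op: callable acting on one fractional position (hypothetical stand-in for t).
    Returns (perm, shifts) with op(x_a) = x_{perm[a]} + shifts[a].
    """
    sites = np.asarray(frac_positions, dtype=float)
    perm, shifts = [], []
    for site in sites:
        diff = np.asarray(op(site), dtype=float) - sites
        # The image is acceptable if it differs from a home-cell site by an integer vector.
        residual = np.abs(diff - np.round(diff)).max(axis=1)
        hits = np.flatnonzero(residual < tol)
        if hits.size == 0:
            raise ValueError(f"site {site} is not closed under the operation")
        j = int(hits[0])
        perm.append(j)
        shifts.append(np.round(diff[j]).astype(int))
    return perm, np.array(shifts)


# Inversion x -> -x on orbitals at 0.25 and 0.75 swaps them and wraps by R = -1.
perm, shifts = match_transformed_sites([[0.25], [0.75]], lambda x: -x)
```

A basis with orbitals at 0.0 and 0.3 fails the same check, because -0.3 wraps to 0.7, which is not a site of the home cell.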

Parameters:

t (Opr, required): Operator to apply. It must satisfy the compatibility conditions described in the Notes section above.

tensor (Tensor, required): Momentum-space tensor with dims (MomentumSpace, HilbertSpace, HilbertSpace).

Returns:

Tensor: Transformed tensor with a transformed MomentumSpace axis and HilbertSpace matrix axes.

Raises:

ValueError: If tensor is not rank 3 with a MomentumSpace axis and two HilbertSpace axes. Also raised if a Hilbert-space side is not closed under the action of t.

Source code in src/qten/bands.py
def bandtransform(
    t: Opr,
    tensor: Tensor,
) -> Tensor:
    r"""
    Apply a basis transform to a momentum-resolved operator tensor.

    The expected tensor shape is `(K, B_left, B_right)` where `K` is a
    [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] and
    `B_left`, `B_right` are
    [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace] axes. This
    function applies the operator-induced basis transform to both
    Hilbert-space legs of the band tensor.

    For each transformed side, a k-dependent matrix is built from the action of
    `t` on the Hilbert-space basis and Fourier transforms that connect Bloch and
    real-space sectors.

    Mathematical action
    -------------------
    Let \(B\) be the input Hilbert-space basis and \(tB\) the transformed basis.
    After wrapping transformed sites back to the home unit cell, the finite
    Fourier transform contributes a momentum-dependent phase. The resulting
    basis-change matrix is denoted \(U_t(k)\). The transformed band block is
    \(H'(t k) = U_t(k)\,H(k)\,U_t(k)^\dagger\). In code, `left_fourier` and
    `right_fourier` are the two \(U_t(k)\)-style maps, and the products are
    `left_fourier @ tensor` and `tensor @ right_fourier.h(-2, -1)`.

    Momentum handling
    -----------------
    - The action on [`Momentum`][qten.geometries.spatials.Momentum] is treated as a relabeling/permutation of sectors.
    - The output tensor carries the transformed momentum axis
      `mapped_kspace = {t @ k | k in kspace}`.
    - Each output k-block is populated from the preimage source block before
      the Hilbert-space conjugation is applied.

    Notes
    -----
    This function accepts a general [`Opr`][qten.symbolics.hilbert_space.Opr], but not every [`Opr`][qten.symbolics.hilbert_space.Opr] is valid here.
    In practice, `t` must act coherently across the real-space and
    momentum-space labels carried by the tensor:

    - `t @ k` must be defined for each [`Momentum`][qten.geometries.spatials.Momentum] in the first tensor axis.
    - `t @ psi` must be defined for each [`U1Basis`][qten.symbolics.hilbert_space.U1Basis] in the Hilbert-space axes,
      in particular for the [`Offset`][qten.geometries.spatials.Offset] irrep stored inside each basis state.
    - The Hilbert-space action and momentum action must be dual-compatible, so
      that the Fourier transform remains consistent after applying `t`.
    - After applying [`FuncOpr(Offset, Offset.fractional)`][qten.symbolics.hilbert_space.FuncOpr], the transformed
      Hilbert space must have the same rays as the original one; otherwise the
      transformed basis does not close on the input band space and this
      function raises `ValueError`.

    Operators that only act on abstract [`U1Basis`][qten.symbolics.hilbert_space.U1Basis] values or only on [`Momentum`][qten.geometries.spatials.Momentum]
    values are not sufficient. The operator must provide matching actions on
    site offsets and crystal momentum.

    Parameters
    ----------
    t : Opr
        Operator to apply. It must satisfy the compatibility conditions
        described in the Notes section.
    tensor : Tensor
        Momentum-space tensor with dims
        `(MomentumSpace, HilbertSpace, HilbertSpace)`.

    Returns
    -------
    Tensor
        Transformed tensor with a transformed
        [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] axis and
        [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace] matrix axes.

    Raises
    ------
    ValueError
        If `tensor` is not rank 3 with a
        [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] axis and
        two [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace] axes.
        Also raised if a Hilbert-space side is not closed under the action of
        `t`.
    """
    if not len(tensor.dims) == 3:
        raise ValueError("Input tensor must have exactly 3 dimensions.")
    if not isinstance(tensor.dims[0], MomentumSpace):
        raise ValueError("First dimension of tensor must be a MomentumSpace.")
    if not isinstance(tensor.dims[1], HilbertSpace):
        raise ValueError("Second dimension of tensor must be a HilbertSpace.")
    if not isinstance(tensor.dims[2], HilbertSpace):
        raise ValueError("Third dimension of tensor must be a HilbertSpace.")

    kspace: MomentumSpace = cast(MomentumSpace, tensor.dims[0])
    transform_cache: Dict[HilbertSpace, Tensor] = {}

    mapped_kspace = _momentum_map(kspace, lambda k: cast(Momentum, t @ k))

    def build_transform(space: HilbertSpace) -> Tensor:
        cached = transform_cache.get(space)
        if cached is not None:
            return cached

        fractional = FuncOpr(Offset, Offset.fractional)
        raw_space = cast(HilbertSpace, t @ space)
        new_space = cast(HilbertSpace, fractional @ raw_space)
        # Applying `t` can move sites out of the home unit cell; `fractional`
        # wraps them back so the result can be compared with `space`.
        if not space.same_rays(new_space):
            raise ValueError(
                f"Hilbert space {space} is not closed under the transform {t}!"
            )
        # `raw_space` keeps the transformed positions before wrapping them back
        # into the home cell; `new_space` is the corresponding wrapped basis.
        # Their difference is the lattice translation whose Bloch phase is
        # encoded by the Fourier transform below.
        transformed_fourier = fourier_transform(
            mapped_kspace, new_space, raw_space, device=tensor.device
        ).replace_dim(2, space)  # (K, B, B')
        # This is the home-cell basis map analogous to the Julia
        # `homefocktransform`: it relabels the wrapped transformed basis back
        # onto the original Hilbert-space labels.
        home_transform = cast(
            Tensor, space.cross_gram(new_space, device=tensor.device)
        ).replace_dim(1, new_space)
        transform = home_transform @ transformed_fourier  # (K, B, B)
        transform_cache[space] = transform
        return transform

    tensor = tensor.replace_dim(0, mapped_kspace)

    left_fourier = build_transform(cast(HilbertSpace, tensor.dims[1]))  # (K, B, B)
    left_fourier = left_fourier.replace_dim(0, mapped_kspace)  # (K, B, B)
    tensor = cast(Tensor, (left_fourier @ tensor))  # (K, B, B)

    right_fourier = build_transform(cast(HilbertSpace, tensor.dims[2]))  # (K, B, B)
    right_fourier = right_fourier.replace_dim(0, mapped_kspace)  # (K, B, B)
    tensor = cast(Tensor, (tensor @ right_fourier.h(-2, -1)))  # (K, B, B)

    return tensor

bandfold

bandfold(
    transform: BasisTransform, tensor: Tensor
) -> Tensor

Fold a momentum-resolved band tensor into the Brillouin zone of a transformed lattice basis.

The input tensor is expected to have dimensions (MomentumSpace, HilbertSpace, HilbertSpace). The basis transformation is applied to the direct lattice underlying the MomentumSpace axis, which produces a new Brillouin zone and a corresponding momentum remapping. One HilbertSpace leg is enlarged to match the transformed unit cell, a Fourier-space change of basis is applied, and the momentum sectors are then gathered into the new momentum grid.
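The momentum remapping can be illustrated with a hypothetical helper in the spirit of the M_rebase step in the source: express each old momentum in the new reciprocal basis, then match it to the new grid modulo reciprocal lattice vectors. Storing basis vectors as matrix columns is an assumption of this sketch, not a statement about QTen's internal layout.

```python
import numpy as np


def fold_indices(old_frac, old_basis, new_frac, new_basis, tol=1e-9):
    """Index of the new-grid momentum that each old momentum folds onto.

    old_frac, new_frac: (K, d) fractional coordinates in their own bases.
    old_basis, new_basis: (d, d) arrays whose columns are reciprocal basis vectors.
    """
    # Cartesian k = old_basis @ f_old = new_basis @ f_new, so f_new = M @ f_old.
    m_rebase = np.linalg.solve(new_basis, old_basis)
    images = np.asarray(old_frac, dtype=float) @ m_rebase.T
    new_frac = np.asarray(new_frac, dtype=float)
    indices = np.empty(len(images), dtype=int)
    for i, q in enumerate(images):
        diff = q - new_frac
        residual = np.abs(diff - np.round(diff)).max(axis=1)
        hits = np.flatnonzero(residual < tol)
        if hits.size == 0:
            raise ValueError(f"momentum {q} has no partner in the new grid")
        indices[i] = hits[0]
    return indices


# 1D example: doubling the direct cell halves the reciprocal basis vector.
old_frac = np.arange(8).reshape(-1, 1) / 8
new_frac = np.arange(4).reshape(-1, 1) / 4
idx = fold_indices(old_frac, 2 * np.pi * np.eye(1), new_frac, np.pi * np.eye(1))
```

Each of the four folded sectors receives exactly two old momenta, k and k + π.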

Mathematical action

A forward basis transform coarsens the direct lattice basis, so the reciprocal Brillouin zone shrinks and multiple old momenta fold onto one new momentum sector. If \(F(k)\) is the Fourier map from the old cell basis into the enlarged transformed-cell basis, each block is transformed as \(H_{\mathrm{fold}}(k') \mathrel{+}= F(k)^\dagger H(k) F(k)\), with \(k' = \mathrm{fold}(k)\). The code-level implementation uses fh @ tensor @ f for the block transform and index_add(0, k_indices, transformed) to accumulate old sectors into the folded momentum axis.
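A toy NumPy version of this fold makes the accumulation concrete. It assumes a hypothetical one-orbital chain with \(H(k) = -2t\cos k\), a two-cell supercell with sites at offsets 0 and 1, and \(F(k)_{0s} = e^{iks}/\sqrt{2}\), where the \(1/\sqrt{2}\) plays the role of vratio; `np.add.at` stands in for index_add. Momenta \(k\) and \(k+\pi\) land in the same folded sector, so each \(2\times 2\) folded block should carry the two energies \(\pm 2t\cos k\).

```python
import numpy as np


def fold_chain(t=1.0, n_k=8):
    """Fold a hypothetical one-orbital chain into a two-site supercell."""
    ks = 2 * np.pi * np.arange(n_k) / n_k
    h = (-2 * t * np.cos(ks)).reshape(n_k, 1, 1)      # (K, B, B), B = 1
    offsets = np.array([0, 1])                         # supercell sites
    # F(k): (K, B, B') Fourier map, normalised like vratio = sqrt(2 / 1).
    f = np.exp(1j * ks[:, None, None] * offsets[None, None, :]) / np.sqrt(2)
    fh = np.conj(np.swapaxes(f, -1, -2))               # (K, B', B)
    transformed = fh @ h @ f                           # (K, B', B')
    # k_m and k_m + pi share the folded sector m mod (n_k / 2).
    k_indices = np.arange(n_k) % (n_k // 2)
    folded = np.zeros((n_k // 2, 2, 2), dtype=complex)
    np.add.at(folded, k_indices, transformed)          # analogue of index_add
    return ks[: n_k // 2], folded


folded_ks, folded = fold_chain()
```

The off-diagonal entry of each folded block works out to \(-t(1 + e^{2i\bar{k}})\), the Bloch Hamiltonian of the same chain written directly in the doubled cell.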

Parameters:

transform (BasisTransform, required): Basis change applied to the direct lattice associated with the momentum axis.

tensor (Tensor, required): Rank-3 tensor with dimensions (MomentumSpace, HilbertSpace, HilbertSpace).

Returns:

Tensor: Folded tensor on the transformed MomentumSpace grid with transformed HilbertSpace matrix axes.

Raises:

ValueError: If the tensor is not rank-3, if the momentum space is empty, or if the momentum axis does not belong to a single Brillouin zone.

TypeError: If the momentum axis is not a MomentumSpace, if its underlying space is not a ReciprocalLattice, or if the selected Hilbert-space leg is not a HilbertSpace.

Source code in src/qten/bands.py
def bandfold(
    transform: BasisTransform,
    tensor: Tensor,
) -> Tensor:
    r"""
    Fold a momentum-resolved band tensor into the Brillouin zone of a
    transformed lattice basis.

    The input tensor is expected to have dimensions
    `(MomentumSpace, HilbertSpace, HilbertSpace)`. The basis transformation is
    applied to the direct lattice underlying the
    [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] axis, which
    produces a new Brillouin zone and a corresponding momentum remapping. One
    [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace] leg is enlarged
    to match the transformed unit cell, a Fourier-space change of basis is
    applied, and the momentum sectors are then gathered into the new momentum
    grid.

    Mathematical action
    -------------------
    A forward basis transform coarsens the direct lattice basis, so the
    reciprocal Brillouin zone shrinks and multiple old momenta fold onto one
    new momentum sector. If \(F(k)\) is the Fourier map from the old cell basis
    into the enlarged transformed-cell basis, each block is transformed as
    \(H_{\mathrm{fold}}(k') \mathrel{+}= F(k)^\dagger H(k) F(k)\), with
    \(k' = \mathrm{fold}(k)\). The code-level implementation uses
    `fh @ tensor @ f` for the block transform and
    `index_add(0, k_indices, transformed)` to accumulate old sectors into the
    folded momentum axis.

    Parameters
    ----------
    transform : BasisTransform
        Basis change applied to the direct lattice associated with the momentum
        axis.
    tensor : Tensor
        Rank-3 tensor with dimensions
        `(MomentumSpace, HilbertSpace, HilbertSpace)`.

    Returns
    -------
    Tensor
        Folded tensor on the transformed
        [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] grid with
        transformed [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace]
        matrix axes.

    Raises
    ------
    ValueError
        If the tensor is not rank-3, if the momentum space is empty, or if the
        momentum axis does not belong to a single Brillouin zone.
    TypeError
        If the momentum axis is not a
        [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace], if its
        underlying space is not a
        [`ReciprocalLattice`][qten.geometries.spatials.ReciprocalLattice], or
        if the selected Hilbert-space leg is not a
        [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace].
    """
    # 1. Parse inputs
    if not tensor.rank() == 3:
        raise ValueError(
            f"Input tensor must be of rank 3, but has rank {tensor.rank()}"
        )
    if not isinstance(tensor.dims[0], MomentumSpace):
        raise TypeError(
            "The first dimension of the tensor must be a MomentumSpace, "
            f"but is of type {type(tensor.dims[0])}"
        )
    k_space = cast(MomentumSpace, tensor.dims[0])
    if not k_space.elements():
        raise ValueError("MomentumSpace is empty")
    lattice_set = set(map(lambda k: k.space, k_space))
    if len(lattice_set) != 1:
        raise ValueError("Invalid BZ")
    reciprocal_lattice = lattice_set.pop()
    if not isinstance(reciprocal_lattice, ReciprocalLattice):
        raise TypeError(
            f"Space of momentum should be ReciprocalLattice, but got {type(reciprocal_lattice)}"
        )
    reciprocal_lattice = cast(ReciprocalLattice, reciprocal_lattice)
    lattice = reciprocal_lattice.dual

    # 2. Apply the transformation
    scaled_lattice = transform(lattice)

    # 3. Create new transformed spaces
    scaled_reciprocal_lattice = scaled_lattice.dual
    transformed_unit_cell = tuple(
        sorted(scaled_lattice.unit_cell.values(), key=lambda x: tuple(x.rep))
    )
    # Keep a rebased copy for the current Fourier/matching logic, but return
    # the transformed offsets on the output Hilbert-space labels.
    enlarge_unit_cell = tuple(r.rebase(lattice) for r in transformed_unit_cell)

    # Follow the existing "both" branch behavior by rebuilding the right leg.
    target_space = tensor.dims[-1]
    if not isinstance(target_space, HilbertSpace):
        raise TypeError(
            f"The last dimension must be a HilbertSpace, but got {type(target_space)}"
        )
    rebased_hilbert = HilbertSpace.new(
        cast(U1Basis, target_space.lookup({Offset: r.fractional()})).replace(r)
        for r in enlarge_unit_cell
    )
    transformed_hilbert = HilbertSpace.new(
        cast(U1Basis, target_space.lookup({Offset: r_lookup.fractional()})).replace(
            r_out
        )
        for r_lookup, r_out in zip(enlarge_unit_cell, transformed_unit_cell)
    )
    # Transform both sides: F(k)^dagger H(k) F(k) for every old momentum.
    f = fourier_transform(k_space, target_space, rebased_hilbert, device=tensor.device)
    vratio = np.sqrt(len(enlarge_unit_cell) / len(lattice.unit_cell))
    f = f / vratio
    fh = f.h(-2, -1)  # (K, B', B)
    transformed = fh @ tensor @ f  # (K, B', B')

    # k-mapping: batch-compute which new-BZ slot each old k-point folds into.
    new_k_space = brillouin_zone(scaled_reciprocal_lattice)

    precision = get_precision_config()
    old_basis_np = np.array(reciprocal_lattice.basis.evalf(), dtype=precision.np_float)
    new_basis_np = np.array(
        scaled_reciprocal_lattice.basis.evalf(), dtype=precision.np_float
    )
    M_rebase = np.linalg.solve(new_basis_np, old_basis_np)

    k_indices = _momentum_match_indices(
        k_space, new_k_space, M_rebase, device=tensor.device
    )

    transformed = (
        zeros((new_k_space, rebased_hilbert, rebased_hilbert), device=tensor.device)
        .astype(transformed.data.dtype)
        .index_add(0, k_indices, transformed)
    )
    for dim in (1, 2):
        if transformed.dims[dim] == rebased_hilbert:
            transformed = transformed.replace_dim(dim, transformed_hilbert)
    return transformed

bandunfold

bandunfold(
    inverse_transform: InverseBasisTransform, tensor: Tensor
) -> Tensor

Unfold a folded momentum-resolved band tensor using an inverse basis transform.

The input is expected to have dimensions (MomentumSpace, HilbertSpace, HilbertSpace) where the MomentumSpace axis lives on a transformed (folded) Brillouin zone. The inverse transform maps that folded lattice back to the primitive one and recovers dimensions (K_primitive, B_primitive, B_primitive).

Mathematical action

Unfolding routes each primitive momentum \(k\) to its parent folded momentum \(\bar{k}\), gathers \(H_{\mathrm{fold}}(\bar{k})\), and then projects it back to the primitive-cell basis with a Fourier map \(F(k)\): \(H_{\mathrm{unfold}}(k) = F(k)\,H_{\mathrm{fold}}(\bar{k})\,F(k)^\dagger\). In code, the parent-sector lookup is tensor.data[k_indices.data], and the final basis projection is f @ gathered @ f.h(-2, -1).
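Using the same toy conventions as a check (a hypothetical one-orbital chain, a two-site supercell at offsets 0 and 1, and \(F(k)_{0s} = e^{iks}/\sqrt{2}\)), the sketch below starts from the supercell Bloch Hamiltonian \(H_{\mathrm{fold}}(\bar{k})\), gathers the parent block of each primitive momentum by plain indexing, and projects with \(F(k)\,H_{\mathrm{fold}}(\bar{k})\,F(k)^\dagger\). Under these assumptions the result should reproduce the primitive dispersion \(-2t\cos k\).

```python
import numpy as np


def supercell_blocks(folded_ks, t=1.0):
    """Bloch Hamiltonian of a hypothetical chain written in a two-site supercell."""
    h = np.zeros((len(folded_ks), 2, 2), dtype=complex)
    h[:, 0, 1] = -t * (1 + np.exp(2j * folded_ks))
    h[:, 1, 0] = np.conj(h[:, 0, 1])
    return h


def unfold_chain(folded, n_k):
    """Unfold (K/2, 2, 2) supercell blocks back to (K, 1, 1) primitive blocks."""
    ks = 2 * np.pi * np.arange(n_k) / n_k
    offsets = np.array([0, 1])
    # F(k): (K, 1, 2), normalised like vratio = sqrt(2 / 1).
    f = np.exp(1j * ks[:, None, None] * offsets[None, None, :]) / np.sqrt(2)
    k_indices = np.arange(n_k) % (n_k // 2)   # parent folded sector of each k
    gathered = folded[k_indices]              # analogue of tensor.data[k_indices.data]
    return ks, f @ gathered @ np.conj(np.swapaxes(f, -1, -2))


n_k = 8
folded_ks = 2 * np.pi * np.arange(n_k // 2) / n_k   # folded grid in [0, pi)
ks, unfolded = unfold_chain(supercell_blocks(folded_ks), n_k)
```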

Parameters:

inverse_transform (InverseBasisTransform, required): Inverse basis transform that maps the folded direct lattice back to the primitive lattice.

tensor (Tensor, required): Rank-3 folded band tensor with dimensions (MomentumSpace, HilbertSpace, HilbertSpace).

Returns:

Tensor: Unfolded tensor on the primitive Brillouin-zone MomentumSpace grid with primitive HilbertSpace matrix axes.

Raises:

TypeError: If inverse_transform is not an InverseBasisTransform, if the tensor axes do not have the required symbolic space types, or if the momentum axis is not backed by a ReciprocalLattice.

ValueError: If tensor is not rank 3, if the momentum space is empty, or if the momentum axis mixes incompatible reciprocal lattices.

Source code in src/qten/bands.py
def bandunfold(
    inverse_transform: InverseBasisTransform,
    tensor: Tensor,
) -> Tensor:
    r"""
    Unfold a folded momentum-resolved band tensor using an inverse basis transform.

    The input is expected to have dimensions `(MomentumSpace, HilbertSpace,
    HilbertSpace)` where the
    [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] axis lives on a
    transformed (folded) Brillouin zone. The inverse transform maps that folded
    lattice back to the primitive one and recovers dimensions
    `(K_primitive, B_primitive, B_primitive)`.

    Mathematical action
    -------------------
    Unfolding routes each primitive momentum \(k\) to its parent folded
    momentum \(\bar{k}\), gathers \(H_{\mathrm{fold}}(\bar{k})\), and then
    projects it back to the primitive-cell basis with a Fourier map \(F(k)\):
    \(H_{\mathrm{unfold}}(k)
    = F(k)\,H_{\mathrm{fold}}(\bar{k})\,F(k)^\dagger\). In code, the
    parent-sector lookup is `tensor.data[k_indices.data]`, and the final basis
    projection is `f @ gathered @ f.h(-2, -1)`.

    Parameters
    ----------
    inverse_transform : InverseBasisTransform
        Inverse basis transform that maps the folded direct lattice back to the
        primitive lattice.
    tensor : Tensor
        Rank-3 folded band tensor with dimensions
        `(MomentumSpace, HilbertSpace, HilbertSpace)`.

    Returns
    -------
    Tensor
        Unfolded tensor on the primitive Brillouin-zone
        [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] grid with
        primitive [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace]
        matrix axes.

    Raises
    ------
    TypeError
        If `inverse_transform` is not an
        [`InverseBasisTransform`][qten.geometries.basis_transform.InverseBasisTransform],
        if the tensor axes do not have the required symbolic space types, or if
        the momentum axis is not backed by a
        [`ReciprocalLattice`][qten.geometries.spatials.ReciprocalLattice].
    ValueError
        If `tensor` is not rank 3, if the momentum space is empty, or if the
        momentum axis mixes incompatible reciprocal lattices.
    """
    if not isinstance(inverse_transform, InverseBasisTransform):
        raise TypeError(
            "bandunfold requires InverseBasisTransform, "
            f"but got {type(inverse_transform)}"
        )
    if tensor.rank() != 3:
        raise ValueError(
            f"Input tensor must be of rank 3, but has rank {tensor.rank()}"
        )
    if not isinstance(tensor.dims[0], MomentumSpace):
        raise TypeError(
            "The first dimension of the tensor must be a MomentumSpace, "
            f"but is of type {type(tensor.dims[0])}"
        )
    if not isinstance(tensor.dims[1], HilbertSpace):
        raise TypeError(
            "The second dimension of the tensor must be a HilbertSpace, "
            f"but is of type {type(tensor.dims[1])}"
        )
    if not isinstance(tensor.dims[2], HilbertSpace):
        raise TypeError(
            "The third dimension of the tensor must be a HilbertSpace, "
            f"but is of type {type(tensor.dims[2])}"
        )

    k_space = cast(MomentumSpace, tensor.dims[0])
    if not k_space.elements():
        raise ValueError("MomentumSpace is empty")
    lattice_set = set(map(lambda k: k.space, k_space))
    if len(lattice_set) != 1:
        raise ValueError("Invalid BZ")
    folded_reciprocal_lattice = lattice_set.pop()
    if not isinstance(folded_reciprocal_lattice, ReciprocalLattice):
        raise TypeError(
            "Space of momentum should be ReciprocalLattice, but got "
            f"{type(folded_reciprocal_lattice)}"
        )
    folded_reciprocal_lattice = cast(ReciprocalLattice, folded_reciprocal_lattice)
    folded_lattice = folded_reciprocal_lattice.dual

    primitive_lattice = cast(Lattice, inverse_transform(folded_lattice))

    folded_hilbert = cast(HilbertSpace, tensor.dims[2])

    primitive_reciprocal_lattice = primitive_lattice.dual
    primitive_k_space = brillouin_zone(primitive_reciprocal_lattice)

    rebased_states = []
    for psi in folded_hilbert:
        u1_psi = cast(U1Basis, psi)
        rebased_states.append(
            u1_psi.replace(u1_psi.irrep_of(Offset).rebase(primitive_lattice))
        )
    rebased_hilbert = HilbertSpace.new(rebased_states)

    primitive_states: "OrderedDict[U1Basis, int]" = OrderedDict()
    for psi in rebased_states:
        primitive_state = psi.replace(psi.irrep_of(Offset).fractional())
        if primitive_state not in primitive_states:
            primitive_states[primitive_state] = len(primitive_states)
    primitive_hilbert = HilbertSpace(structure=primitive_states)

    # Route each primitive-k sector to its folded-k parent.
    precision = get_precision_config()
    primitive_basis_np = np.array(
        primitive_reciprocal_lattice.basis.evalf(), dtype=precision.np_float
    )
    folded_basis_np = np.array(
        folded_reciprocal_lattice.basis.evalf(), dtype=precision.np_float
    )
    M_rebase = np.linalg.solve(folded_basis_np, primitive_basis_np)
    k_indices = _momentum_match_indices(
        primitive_k_space, k_space, M_rebase, device=tensor.device
    )

    gathered = Tensor(
        data=tensor.data[k_indices.data],
        dims=(primitive_k_space, tensor.dims[1], tensor.dims[2]),
    )
    for dim in (1, 2):
        if gathered.dims[dim] == folded_hilbert:
            gathered = gathered.replace_dim(dim, rebased_hilbert)

    f = fourier_transform(
        primitive_k_space, primitive_hilbert, rebased_hilbert, device=tensor.device
    )
    vratio = np.sqrt(rebased_hilbert.dim / primitive_hilbert.dim)
    f = f / vratio
    unfolded = f @ gathered @ f.h(-2, -1)
    return unfolded
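The gather-then-transform step above, \(H_{\mathrm{prim}}(k) = F(k)\,H_{\mathrm{fold}}(K(k))\,F(k)^\dagger\) with the Fourier matrix divided by `vratio`, can be illustrated on a one-dimensional chain. This is a minimal NumPy sketch and does not call QTen: `folded_block` and `unfold_at` are hypothetical helpers, the phase convention \(e^{-ik r_s}\) for an orbital at offset \(r_s\) is an assumption, and dividing by \(\sqrt{2}\) plays the role of `vratio` (rebased dimension 2 over primitive dimension 1). Each primitive momentum is first routed to its folded parent \(K = k \bmod \pi\), mirroring the `k_indices` gather.

```python
import numpy as np


def folded_block(K: float, t: float = 1.0) -> np.ndarray:
    """2x2 Bloch Hamiltonian of a nearest-neighbour chain in a two-site supercell.

    Orbitals sit at offsets 0 and 1 and the supercell period is 2, so the
    off-diagonal entry combines the intra-cell bond (0 -> 1) and the bond to
    the previous cell (0 -> -1).
    """
    h01 = -t * (1.0 + np.exp(-2j * K))
    return np.array([[0.0, h01], [np.conj(h01), 0.0]], dtype=complex)


def unfold_at(k: float, t: float = 1.0, offsets=(0.0, 1.0), n_primitive: int = 1) -> complex:
    """Return F(k) H_fold(K) F(k)^dagger for one primitive momentum k."""
    # Route k to its folded parent K in the supercell zone [-pi/2, pi/2).
    K = (k + np.pi / 2) % np.pi - np.pi / 2
    # Row of Fourier phases e^{-i k r_s}, scaled by 1/sqrt(n_folded / n_primitive).
    phases = np.exp(-1j * k * np.asarray(offsets))[None, :]
    f = phases / np.sqrt(len(offsets) / n_primitive)
    return complex((f @ folded_block(K, t) @ f.conj().T)[0, 0])


ks = np.linspace(-np.pi, np.pi, 9)
bands = np.array([unfold_at(k) for k in ks])  # should reproduce -2 t cos(k) with t = 1
```

The unfolded values recover the primitive band \(-2t\cos k\) over the full primitive zone, even though the folded Hamiltonian is only ever evaluated inside the halved supercell zone.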

bandfillings

bandfillings(tensor: Tensor, frac: float) -> Tensor

Return eigenvectors for occupied bands up to a filling fraction.

The input tensor is expected to have dimensions (MomentumSpace, HilbertSpace, HilbertSpace), where the MomentumSpace axis indexes momentum sectors and the two HilbertSpace axes form the Hamiltonian matrix at each momentum. The tensor is diagonalized at each momentum, and the eigenvectors with energies at or below the global filling threshold are packed into an output IndexSpace.

Mathematical convention

Each momentum block is diagonalized as \(H(k) V(k) = V(k) E(k)\). If frac = f, the target number of occupied states is

\(N_{\mathrm{occ}} = \left\lfloor f\,N_k\,N_b \right\rfloor\), where \(N_k\) is the number of momentum sectors and \(N_b\) is the number of bands per sector. The global filling threshold is the \(N_{\mathrm{occ}}\)-th smallest eigenvalue over all momentum sectors, and every eigenvector whose energy lies at or below it (up to a small floating-point tolerance) is retained. Degenerate states at the threshold are therefore included together.
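In floating point, \(f\,N_k\,N_b\) can land just below an integer, which is presumably why the implementation computes `int(np.floor(frac * total_states + 1e-12))` rather than a bare floor. The hypothetical helper `target_fill` below reproduces that guard with the standard library only; in IEEE-754 double precision, `0.29 * 100` evaluates to `28.999999999999996`.

```python
import math


def target_fill(frac: float, n_k: int, n_b: int) -> int:
    """Occupied-state count floor(f * N_k * N_b) with a small rounding guard.

    The 1e-12 offset mirrors the guard in `bandfillings`: it absorbs
    rounding error that would otherwise drop one state.
    """
    total = n_k * n_b
    return int(math.floor(frac * total + 1e-12))


naive = math.floor(0.29 * 100)       # 28, because 0.29 * 100 == 28.999999999999996
guarded = target_fill(0.29, 10, 10)  # 29
half = target_fill(0.5, 4, 3)        # floor(6.0) == 6
```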

Degenerate threshold behavior

If one state in a degenerate set is filled, all states in that set are filled. Because the threshold is global rather than per sector, the number of filled states generally differs between momentum sectors; the output index dimension is the maximum filled count over all sectors, and sectors with fewer filled states are padded with zero columns.
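The threshold-and-pad logic can be sketched with NumPy on already-diagonalized data, assuming `eigvals` has shape `(N_k, N_b)` in ascending order per sector and `eigvecs` has shape `(N_k, N, N_b)` with eigenvectors in columns, as `eigh` returns them. `fill_bands` is a hypothetical stand-in, not the QTen function: it skips the `Tensor`/`IndexSpace` wrapping and returns the per-sector counts alongside the packed vectors. In the example, frac = 0.5 gives \(N_{\mathrm{occ}} = 2\), the threshold is 0, and the degenerate pair at energy 0 fills three states, so the second sector is padded with one zero column.

```python
import numpy as np


def fill_bands(eigvals: np.ndarray, eigvecs: np.ndarray, frac: float):
    """Return (packed eigenvectors, per-sector filled counts)."""
    nk, nb = eigvals.shape
    total = nk * nb
    target = int(np.floor(frac * total + 1e-12))
    if target <= 0:
        return eigvecs[..., :0], np.zeros(nk, dtype=int)
    # Global threshold: the target-th smallest eigenvalue over all sectors.
    threshold = np.sort(eigvals, axis=None)[min(target, total) - 1]
    tol = max(abs(threshold), 1.0) * np.finfo(eigvals.dtype).eps * max(nb, 1) * 8
    filled = eigvals <= threshold + tol  # degenerate partners are filled too
    counts = filled.sum(axis=1)
    max_fill = int(counts.max())
    # Stable sort moves filled columns first while keeping their energy order.
    order = np.argsort(~filled, axis=1, kind="stable")
    idx = np.broadcast_to(order[:, None, :], eigvecs.shape)
    packed = np.take_along_axis(eigvecs, idx, axis=2)[..., :max_fill]
    # Zero the padding columns in sectors with fewer filled states.
    valid = np.arange(max_fill)[None, :] < counts[:, None]
    return packed * valid[:, None, :], counts


eigvals = np.array([[-1.0, 0.0], [0.0, 1.0]])
eigvecs = np.tile(np.eye(2), (2, 1, 1))  # identity eigenvectors in both sectors
packed, counts = fill_bands(eigvals, eigvecs, 0.5)
```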

Parameters:

Name | Type | Description | Default
tensor | Tensor | Band-resolved tensor with dimensions (MomentumSpace, HilbertSpace, HilbertSpace). | required
frac | float | Filling fraction in the inclusive range [0, 1]. | required

Returns:

Type | Description
Tensor | Eigenvector tensor with dimensions (MomentumSpace, HilbertSpace, IndexSpace). For each momentum sector, columns along IndexSpace contain the eigenvectors selected as filled. The IndexSpace size is the largest filled count among all momentum sectors; sectors with fewer filled bands are padded with zero columns.

Raises:

Type | Description
TypeError | If the tensor axes are not MomentumSpace, HilbertSpace, and HilbertSpace, respectively.
ValueError | If tensor is not rank 3, or if frac is outside the inclusive range [0, 1].

Source code in src/qten/bands.py
def bandfillings(tensor: Tensor, frac: float) -> Tensor:
    r"""
    Return eigenvectors for occupied bands up to a filling fraction.

    The input tensor is expected to have dimensions
    `(MomentumSpace, HilbertSpace, HilbertSpace)`, where the
    [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] axis indexes
    momentum sectors and the two
    [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace] axes form the
    Hamiltonian matrix at each momentum. The tensor is diagonalized at each
    momentum, and the eigenvectors with energies at or below the global
    filling threshold are packed into an output
    [`IndexSpace`][qten.symbolics.state_space.IndexSpace].

    Mathematical convention
    -----------------------
    Each momentum block is diagonalized as \(H(k) V(k) = V(k) E(k)\). If
    `frac = f`, the target number of occupied states is

    \(N_{\mathrm{occ}} = \left\lfloor f\,N_k\,N_b \right\rfloor\), where
    \(N_k\) is the number of momentum sectors and \(N_b\) is the number of
    bands per sector. The global filling threshold is the
    \(N_{\mathrm{occ}}\)-th smallest eigenvalue over all momentum sectors,
    and every eigenvector whose energy lies at or below it (up to a small
    floating-point tolerance) is retained. Degenerate states at the threshold
    are therefore included together.

    Degenerate threshold behavior
    -----------------------------
    If one state in a degenerate set is filled, all states in that set are
    filled. Because the threshold is global rather than per sector, the number
    of filled states generally differs between momentum sectors; the output
    index dimension is the maximum filled count over all sectors, and sectors
    with fewer filled states are padded with zero columns.

    Parameters
    ----------
    tensor : Tensor
        Band-resolved tensor with dimensions
        `(MomentumSpace, HilbertSpace, HilbertSpace)`.
    frac : float
        Filling fraction in the inclusive range `[0, 1]`.

    Returns
    -------
    Tensor
        Eigenvector tensor with dimensions `(MomentumSpace, HilbertSpace,
        IndexSpace)`. For each momentum sector, columns along `IndexSpace`
        contain the eigenvectors selected as filled. The `IndexSpace` size is
        the largest filled count among all momentum sectors; sectors with fewer
        filled bands are padded with zero columns.

    Raises
    ------
    TypeError
        If the tensor axes are not `MomentumSpace`, `HilbertSpace`, and
        `HilbertSpace`, respectively.
    ValueError
        If `tensor` is not rank 3. Also raised if `frac` is outside the
        inclusive range `[0, 1]`.
    """
    if tensor.rank() != 3:
        raise ValueError(
            f"Input tensor must be of rank 3, but has rank {tensor.rank()}"
        )
    if not isinstance(tensor.dims[0], MomentumSpace):
        raise TypeError("The first dimension of the tensor must be a MomentumSpace.")
    if not isinstance(tensor.dims[1], HilbertSpace):
        raise TypeError("The second dimension of the tensor must be a HilbertSpace.")
    if not isinstance(tensor.dims[2], HilbertSpace):
        raise TypeError("The third dimension of the tensor must be a HilbertSpace.")
    if not (0.0 <= frac <= 1.0):
        raise ValueError(f"Filling fraction must be between 0 and 1, got {frac}")

    kspace = cast(MomentumSpace, tensor.dims[0])
    band_space = cast(HilbertSpace, tensor.dims[1])
    eigvals, eigvecs = eigh(tensor)

    nk, nbands = eigvals.data.shape
    total_states = nk * nbands
    target_fill = int(np.floor(frac * total_states + 1e-12))

    if target_fill <= 0:
        return Tensor(
            data=eigvecs.data[..., :0],
            dims=(kspace, band_space, IndexSpace.linear(0)),
        )
    if target_fill >= total_states:
        return Tensor(
            data=eigvecs.data,
            dims=(kspace, band_space, IndexSpace.linear(nbands)),
        )

    flat_vals = eigvals.data.reshape(-1)
    threshold = torch.kthvalue(flat_vals, target_fill).values
    eps = torch.finfo(eigvals.data.dtype).eps
    tol = (abs(threshold).clamp_min(1.0) * eps * max(nbands, 1) * 8).to(
        eigvals.data.dtype
    )
    filled = eigvals.data <= (threshold + tol)

    counts = filled.sum(dim=1)
    max_fill = int(counts.max().item())
    out_dim = IndexSpace.linear(max_fill)

    order = torch.argsort(filled.to(torch.int8), dim=1, descending=True, stable=True)
    packed = torch.gather(
        eigvecs.data,
        2,
        order[:, None, :].expand(-1, eigvecs.data.shape[1], -1),
    )[..., :max_fill]

    valid = (
        torch.arange(max_fill, device=counts.device)[None, :] < counts[:, None]
    ).to(packed.dtype)
    packed = packed * valid[:, None, :]

    return Tensor(data=packed, dims=(kspace, band_space, out_dim))

bandselect

bandselect(
    tensor: Tensor,
    **kwargs: Dict[
        str,
        Union[
            slice,
            Tuple[int, ...],
            Tuple[float, float],
            Callable[[float], bool],
        ],
    ],
) -> Dict[str, Tensor]

Select specific bands from a band-resolved Tensor based on criteria provided in kwargs.

The input Tensor is diagonalized at each MomentumSpace sector. Each keyword argument defines one named selection criterion, and the returned dictionary maps each name to a tensor containing the matching eigenvectors. Outputs have dimensions (MomentumSpace, HilbertSpace, IndexSpace), where HilbertSpace labels the band basis and IndexSpace labels the selected states for each criterion.

Mathematical convention

For each momentum sector, \(H(k) v_n(k) = \epsilon_n(k) v_n(k)\), and each criterion selects a subset of band labels \(n\). The returned tensor packs the matching eigenvectors \(v_n(k)\) into an IndexSpace, padding sectors with fewer matches by zero columns.

Supported criteria
  • slice: select bands by sorted energy index, such as slice(0, 2) for the two lowest-energy bands.
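  • Tuple[int, ...]: select explicit sorted band indices, such as (0, 2) for the lowest and third-lowest bands. Negative indices count from the highest band, as in Python indexing.
  • Tuple[float, float]: select an inclusive energy range (lo, hi). A tuple whose entries are all integers is always read as band indices, so write energy bounds as floats, e.g. (-1.0, 0.0) rather than (-1, 0).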
  • Callable[[float], bool]: select energies for which the callable returns True.

If a criterion matches no band in any momentum sector, the corresponding output tensor has an IndexSpace of dimension zero.
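How each criterion turns into a per-sector boolean mask can be sketched with NumPy. `criterion_mask` is a hypothetical helper that follows the order of checks in the source (slice, then all-integer tuple, then numeric 2-tuple, then callable) on an array of ascending eigenvalues of shape `(N_k, N_b)`; it omits the eigendecomposition and the packing step.

```python
from typing import Callable, Tuple, Union

import numpy as np

Criterion = Union[slice, Tuple[int, ...], Tuple[float, float], Callable[[float], bool]]


def criterion_mask(values: np.ndarray, criterion: Criterion) -> np.ndarray:
    """Boolean (N_k, N_b) mask of selected bands; `values` sorted ascending per sector."""
    nk, nb = values.shape
    if isinstance(criterion, slice):
        mask = np.zeros((nk, nb), dtype=bool)
        mask[:, np.arange(nb)[criterion]] = True
        return mask
    if isinstance(criterion, tuple):
        # All-integer tuples are band indices; this check runs before the range check.
        if all(isinstance(x, int) and not isinstance(x, bool) for x in criterion):
            mask = np.zeros((nk, nb), dtype=bool)
            raw = np.asarray(criterion, dtype=int)
            if ((raw < -nb) | (raw >= nb)).any():
                raise IndexError(f"band index out of range: {criterion!r}")
            mask[:, raw % nb] = True  # negative indices wrap like Python indexing
            return mask
        if len(criterion) == 2 and all(
            isinstance(x, (int, float)) and not isinstance(x, bool) for x in criterion
        ):
            lo, hi = criterion
            return (values >= lo) & (values <= hi)  # inclusive energy window
        raise TypeError(f"unsupported tuple criterion: {criterion!r}")
    if callable(criterion):
        return np.vectorize(lambda e: bool(criterion(e)), otypes=[bool])(values)
    raise TypeError(f"unsupported criterion: {criterion!r}")


values = np.array([[-2.0, -0.5, 1.0], [-1.5, 0.2, 2.0]])
window = criterion_mask(values, (-1.0, 0.5))  # energy range: middle band only
indices = criterion_mask(values, (-1, 0))     # band indices: highest and lowest bands
```

With integer endpoints, `(-1, 0)` selects the highest and lowest bands rather than the energy window \([-1, 0]\), which is why energy bounds should be written as floats.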

Parameters:

Name | Type | Description | Default
tensor | Tensor | Band-resolved tensor with dimensions (MomentumSpace, HilbertSpace, HilbertSpace). | required
kwargs | Dict[str, Union[slice, Tuple[int, ...], Tuple[float, float], Callable[[float], bool]]] | Named band-selection criteria. | {}

Returns:

Type | Description
Dict[str, Tensor] | Mapping from criterion name to selected eigenvector tensor with dimensions (MomentumSpace, HilbertSpace, IndexSpace).

Raises:

Type | Description
TypeError | If the tensor axes are not (MomentumSpace, HilbertSpace, HilbertSpace), or if a criterion has an unsupported type.
ValueError | If tensor is not rank 3.
IndexError | If an explicit integer band index is outside the available band range.

Source code in src/qten/bands.py
def bandselect(
    tensor: Tensor,
    **kwargs: Dict[
        str, Union[slice, Tuple[int, ...], Tuple[float, float], Callable[[float], bool]]
    ],
) -> Dict[str, Tensor]:
    r"""
    Select specific bands from a band-resolved [`Tensor`][qten.linalg.tensors.Tensor] based on criteria provided in `kwargs`.

    The input [`Tensor`][qten.linalg.tensors.Tensor] is diagonalized at each
    [`MomentumSpace`][qten.symbolics.state_space.MomentumSpace] sector. Each
    keyword argument defines one named selection criterion, and the returned
    dictionary maps each name to a tensor containing the matching eigenvectors.
    Outputs have dimensions `(MomentumSpace, HilbertSpace, IndexSpace)`, where
    [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace] labels the band
    basis and [`IndexSpace`][qten.symbolics.state_space.IndexSpace] labels the
    selected states for each criterion.

    Mathematical convention
    -----------------------
    For each momentum sector, \(H(k) v_n(k) = \epsilon_n(k) v_n(k)\), and each
    criterion selects a subset of band labels \(n\). The returned
    tensor packs the matching eigenvectors \(v_n(k)\) into an
    [`IndexSpace`][qten.symbolics.state_space.IndexSpace], padding sectors with
    fewer matches by zero columns.

    Supported criteria
    ------------------
    - `slice`: select bands by sorted energy index, such as `slice(0, 2)` for
      the two lowest-energy bands.
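    - `Tuple[int, ...]`: select explicit sorted band indices, such as `(0, 2)`
      for the lowest and third-lowest bands. Negative indices count from the
      highest band, as in Python indexing.
    - `Tuple[float, float]`: select an inclusive energy range `(lo, hi)`. A
      tuple whose entries are all integers is always read as band indices, so
      write energy bounds as floats, e.g. `(-1.0, 0.0)` rather than `(-1, 0)`.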
    - `Callable[[float], bool]`: select energies for which the callable returns
      `True`.

    If a criterion matches no band in any momentum sector, the corresponding
    output tensor has an `IndexSpace` of dimension zero.

    Parameters
    ----------
    tensor : Tensor
        Band-resolved tensor with dimensions
        `(MomentumSpace, HilbertSpace, HilbertSpace)`.
    kwargs : Dict[str, Union[slice, Tuple[int, ...], Tuple[float, float], Callable[[float], bool]]]
        Named band-selection criteria.

    Returns
    -------
    Dict[str, Tensor]
        Mapping from criterion name to selected eigenvector tensor with
        dimensions `(MomentumSpace, HilbertSpace, IndexSpace)`.

    Raises
    ------
    TypeError
        If the tensor axes are not `(MomentumSpace, HilbertSpace,
        HilbertSpace)`, or if a criterion has an unsupported type.
    ValueError
        If `tensor` is not rank 3.
    IndexError
        If an explicit integer band index is outside the available band range.
    """
    if tensor.rank() != 3:
        raise ValueError(
            f"Input tensor must be of rank 3, but has rank {tensor.rank()}"
        )
    if not isinstance(tensor.dims[0], MomentumSpace):
        raise TypeError("The first dimension of the tensor must be a MomentumSpace.")
    if not isinstance(tensor.dims[1], HilbertSpace):
        raise TypeError("The second dimension of the tensor must be a HilbertSpace.")
    if not isinstance(tensor.dims[2], HilbertSpace):
        raise TypeError("The third dimension of the tensor must be a HilbertSpace.")

    kspace = cast(MomentumSpace, tensor.dims[0])
    band_space = cast(HilbertSpace, tensor.dims[1])
    eigvals, eigvecs = eigh(tensor)
    values = eigvals.data
    vectors = eigvecs.data

    nk, nbands = values.shape
    band_indices = torch.arange(nbands, device=values.device)

    def pack(mask: torch.Tensor) -> Tensor:
        counts = mask.sum(dim=1)
        max_count = int(counts.max().item()) if counts.numel() else 0
        out_dim = IndexSpace.linear(max_count)
        if max_count == 0:
            return Tensor(data=vectors[..., :0], dims=(kspace, band_space, out_dim))

        order = torch.argsort(mask.to(torch.int8), dim=1, descending=True, stable=True)
        packed = torch.gather(
            vectors,
            2,
            order[:, None, :].expand(-1, vectors.shape[1], -1),
        )[..., :max_count]
        valid = (
            torch.arange(max_count, device=counts.device)[None, :] < counts[:, None]
        ).to(packed.dtype)
        packed = packed * valid[:, None, :]
        return Tensor(data=packed, dims=(kspace, band_space, out_dim))

    selected: Dict[str, Tensor] = {}
    for name, criterion in kwargs.items():
        mask: torch.Tensor
        if isinstance(criterion, slice):
            picked = band_indices[criterion]
            mask = torch.zeros((nk, nbands), dtype=torch.bool, device=values.device)
            if picked.numel():
                mask[:, picked] = True
        elif isinstance(criterion, tuple):
            if all(isinstance(x, int) and not isinstance(x, bool) for x in criterion):
                mask = torch.zeros((nk, nbands), dtype=torch.bool, device=values.device)
                if criterion:
                    raw_idx = torch.tensor(
                        criterion, dtype=torch.long, device=values.device
                    )
                    if ((raw_idx < -nbands) | (raw_idx >= nbands)).any():
                        raise IndexError(
                            f"Band index out of range in criterion {name!r}"
                        )
                    mask[:, raw_idx % nbands] = True
            elif len(criterion) == 2 and all(
                isinstance(x, (int, float, np.integer, np.floating))
                and not isinstance(x, bool)
                for x in criterion
            ):
                lo, hi = criterion
                mask = (values >= lo) & (values <= hi)
            else:
                raise TypeError(
                    f"Unsupported tuple criterion for {name!r}: {criterion!r}"
                )
        elif callable(criterion):
            mask = torch.tensor(
                [
                    [bool(criterion(v)) for v in row]
                    for row in values.detach().cpu().tolist()
                ],
                dtype=torch.bool,
                device=values.device,
            )
        else:
            raise TypeError(f"Unsupported criterion for {name!r}: {criterion!r}")

        selected[name] = pack(mask)

    return selected

nearest_bands

nearest_bands(
    h_k: Tensor,
    point: Union[str, Sequence[float]] = "Gamma",
    close_to: float = 0.0,
    tol: float = 1e-06,
    points: Optional[Dict[str, Sequence[float]]] = None,
) -> Tensor

Project a momentum-resolved Hamiltonian onto bands selected at one k-point.

The input h_k is diagonalized at a single anchor momentum \(k_0\). Eigenvectors whose anchor eigenvalues lie within tol of close_to are collected into a rectangular matrix \(V\). If the input Hilbert dimension is \(N\) and \(S\) bands are selected, then V has shape (N, S) and the returned tensor stores \(V^\dagger H(k) V\) for every momentum \(k\).

Projection convention

At the selected anchor sector, the code computes eigenvalues, eigenvectors = torch.linalg.eigh(H_anchor). The columns of eigenvectors with \(|\epsilon_n(k_0) - \mathrm{close\_to}| \le \mathrm{tol}\) form \(V\). The projected block at each momentum is \(H_{\mathrm{proj}}(k) = V^\dagger H(k) V\).

In implementation terms, this projection is the einsum torch.einsum("ia,kab,bj->kij", V_dag, h_k.data, V).
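The einsum can be checked on plain arrays with a NumPy sketch. `project_on_anchor` is a hypothetical helper: it takes the anchor as a row index instead of resolving a k-point, and uses `numpy.linalg.eigh` in place of `torch.linalg.eigh`. With an anchor block `diag(-1, 0, 1)` and `close_to = 0`, a single band is selected, so each projected block is 1-by-1; a window that catches no eigenvalue yields shape `(N_k, 0, 0)`.

```python
import numpy as np


def project_on_anchor(h_k: np.ndarray, anchor: int, close_to: float = 0.0,
                      tol: float = 1e-6) -> np.ndarray:
    """Return V^dagger H(k) V for every k, with V taken from the anchor sector."""
    evals, evecs = np.linalg.eigh(h_k[anchor])
    selected = np.flatnonzero(np.abs(evals - close_to) <= tol)
    V = evecs[:, selected]  # (N, S)
    V_dag = V.conj().T      # (S, N)
    return np.einsum("ia,kab,bj->kij", V_dag, h_k, V)


h_k = np.stack([
    np.diag([-1.0, 0.0, 1.0]),
    np.array([[0.0, 0.5, 0.0], [0.5, 2.0, 0.0], [0.0, 0.0, 3.0]]),
])
proj = project_on_anchor(h_k, anchor=0)                 # shape (2, 1, 1)
empty = project_on_anchor(h_k, anchor=0, close_to=5.0)  # shape (2, 0, 0)
```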

Anchor selection
  • A string point is looked up in points.
  • "Gamma" defaults to the fractional origin when absent from points.
  • A coordinate sequence is interpreted directly as fractional coordinates.
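  • Fractional-coordinate differences are wrapped by subtracting the nearest integer, so equivalent periodic coordinates select the same anchor. The anchor \(k_0\) is the sampled momentum whose wrapped difference has the smallest Euclidean norm.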
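The anchor lookup can be sketched in NumPy as below. `nearest_k_row` is a hypothetical helper that returns a row position in an array of sampled fractional momenta; the source additionally maps that row back to the MomentumSpace structure index. Note that distances are measured in fractional coordinates, not Cartesian ones.

```python
import numpy as np


def nearest_k_row(k_frac: np.ndarray, target) -> int:
    """Row of k_frac closest to target after wrapping differences into [-0.5, 0.5]."""
    diff = np.asarray(k_frac, dtype=float) - np.asarray(target, dtype=float)
    diff = diff - np.round(diff)  # subtract the nearest integer per component
    return int(np.argmin(np.linalg.norm(diff, axis=1)))


k_frac = np.array([[0.0, 0.0], [0.5, 0.0], [0.25, 0.75]])
gamma_row = nearest_k_row(k_frac, (1.0, 0.0))    # (1, 0) is periodic-equivalent to Gamma
shifted_row = nearest_k_row(k_frac, (0.24, -0.26))  # wraps onto (0.25, 0.75)
```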

If no eigenvalue falls inside the tolerance window, the result has two zero-dimensional IndexSpace axes and data shape (len(kspace), 0, 0).

Notes

The selected subspace is fixed by the anchor momentum only. The same anchor eigenvector matrix \(V\) is applied to every \(H(k)\); this is a projection onto an anchor-defined subspace, not a separately diagonalized band selection at each momentum.
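A small NumPy example makes the distinction concrete. The two-band model `h` is hypothetical: at \(k_0 = 0\) it is `diag(0, 1)`, so the window around `close_to = 0` selects the first basis vector, and that same \(V\) is reused at \(k = 1\). The projected value there stays 0, whereas diagonalizing \(H(1)\) on its own gives a lowest eigenvalue of \((1 - \sqrt{5})/2\).

```python
import numpy as np


def h(k: float) -> np.ndarray:
    # Hypothetical two-band model: onsite energies 0 and 1, coupling k.
    return np.array([[0.0, k], [k, 1.0]])


# Anchor at k0 = 0: select the eigenvector whose eigenvalue lies within 1e-6 of 0.
evals0, evecs0 = np.linalg.eigh(h(0.0))
V = evecs0[:, np.abs(evals0 - 0.0) <= 1e-6]  # shape (2, 1)

projected_at_1 = (V.conj().T @ h(1.0) @ V)[0, 0]  # anchor-defined subspace
lowest_at_1 = np.linalg.eigvalsh(h(1.0))[0]       # per-k diagonalization
```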

Parameters:

Name | Type | Description | Default
h_k | Tensor | Hamiltonian tensor with dims (MomentumSpace, HilbertSpace, HilbertSpace). | required
point | str or Sequence[float] | Anchor k-point. String labels are resolved through points; "Gamma" falls back to the fractional origin when it is not in points. | "Gamma"
close_to | float | Target eigenvalue for the subspace selection. | 0.0
tol | float | Half-width of the eigenvalue window around close_to. | 1e-6
points | dict[str, Sequence[float]] | Mapping from labels to fractional coordinates. | None

Returns:

Type | Description
Tensor | Projected Hamiltonian with dims (MomentumSpace, IndexSpace, IndexSpace). The last two axes span the selected subspace.

Raises:

Type | Description
ValueError | If h_k is not rank 3, if the momentum space is empty, or if the anchor coordinate dimension does not match the momentum-space dimension.
TypeError | If the input dimensions are not (MomentumSpace, HilbertSpace, HilbertSpace).
KeyError | If point is a string other than "Gamma" and is not present in points.

Source code in src/qten/bands.py
def nearest_bands(
    h_k: Tensor,
    point: Union[str, Sequence[float]] = "Gamma",
    close_to: float = 0.0,
    tol: float = 1e-6,
    points: Optional[Dict[str, Sequence[float]]] = None,
) -> Tensor:
    r"""
    Project a momentum-resolved Hamiltonian onto bands selected at one k-point.

    The input `h_k` is diagonalized at a single anchor momentum \(k_0\).
    Eigenvectors whose anchor eigenvalues lie within `tol` of `close_to` are
    collected into a rectangular matrix \(V\). If the input Hilbert dimension is
    \(N\) and \(S\) bands are selected, then `V` has shape `(N, S)` and the
    returned tensor stores \(V^\dagger H(k) V\) for every momentum \(k\).

    Projection convention
    ---------------------
    At the selected anchor sector, the code computes
    `eigenvalues, eigenvectors = torch.linalg.eigh(H_anchor)`. The columns of
    `eigenvectors` with \(|\epsilon_n(k_0) - \mathrm{close\_to}| \le \mathrm{tol}\)
    form \(V\). The projected block at each momentum is
    \(H_{\mathrm{proj}}(k) = V^\dagger H(k) V\).

    In implementation terms, this projection is the einsum
    `torch.einsum("ia,kab,bj->kij", V_dag, h_k.data, V)`.

    Anchor selection
    ----------------
    - A string `point` is looked up in `points`.
    - `"Gamma"` defaults to the fractional origin when absent from `points`.
    - A coordinate sequence is interpreted directly as fractional coordinates.
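    - Fractional-coordinate differences are wrapped by subtracting the nearest
      integer, so equivalent periodic coordinates select the same anchor. The
      anchor \(k_0\) is the sampled momentum whose wrapped difference has the
      smallest Euclidean norm.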

    If no eigenvalue falls inside the tolerance window, the result has two
    zero-dimensional [`IndexSpace`][qten.symbolics.state_space.IndexSpace] axes
    and data shape `(len(kspace), 0, 0)`.

    Notes
    -----
    The selected subspace is fixed by the anchor momentum only. The same
    anchor eigenvector matrix \(V\) is applied to every \(H(k)\); this is a
    projection onto an anchor-defined subspace, not a separately diagonalized
    band selection at each momentum.

    Parameters
    ----------
    h_k : Tensor
        Hamiltonian tensor with dims
        ([`MomentumSpace`][qten.symbolics.state_space.MomentumSpace],
        [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace],
        [`HilbertSpace`][qten.symbolics.hilbert_space.HilbertSpace]).
    point : str or Sequence[float], default="Gamma"
        Anchor k-point. String labels are resolved through `points`; `"Gamma"`
        falls back to the fractional origin when it is not in `points`.
    close_to : float, default=0.0
        Target eigenvalue for the subspace selection.
    tol : float, default=1e-6
        Half-width of the eigenvalue window around `close_to`.
    points : dict[str, Sequence[float]], optional
        Mapping from labels to fractional coordinates.

    Returns
    -------
    Tensor
        Projected Hamiltonian with dims
        ([`MomentumSpace`][qten.symbolics.state_space.MomentumSpace],
        [`IndexSpace`][qten.symbolics.state_space.IndexSpace],
        [`IndexSpace`][qten.symbolics.state_space.IndexSpace]). The last two
        axes span the selected subspace.

    Raises
    ------
    ValueError
        If `h_k` is not rank 3, if the momentum space is empty, or if the
        anchor coordinate dimension does not match the momentum-space
        dimension.
    TypeError
        If the input dimensions are not
        `(MomentumSpace, HilbertSpace, HilbertSpace)`.
    KeyError
        If `point` is a string other than `"Gamma"` and is not present in
        `points`.
    """
    if h_k.rank() != 3:
        raise ValueError(f"Input tensor must be of rank 3, but has rank {h_k.rank()}")
    if not isinstance(h_k.dims[0], MomentumSpace):
        raise TypeError("The first dimension of the tensor must be a MomentumSpace.")
    if not isinstance(h_k.dims[1], HilbertSpace):
        raise TypeError("The second dimension of the tensor must be a HilbertSpace.")
    if not isinstance(h_k.dims[2], HilbertSpace):
        raise TypeError("The third dimension of the tensor must be a HilbertSpace.")

    kspace = cast(MomentumSpace, h_k.dims[0])
    k_items = list(kspace.structure.items())
    if not k_items:
        raise ValueError("MomentumSpace is empty")

    dim = k_items[0][0].space.dim
    if isinstance(point, str):
        if points is not None and point in points:
            target_frac = tuple(float(x) for x in points[point])
        elif point == "Gamma":
            target_frac = tuple(0.0 for _ in range(dim))
        else:
            raise KeyError(
                f"Point {point!r} not found in `points`; "
                "provide a `points` mapping or pass explicit fractional coordinates."
            )
    else:
        target_frac = tuple(float(x) for x in point)
    if len(target_frac) != dim:
        raise ValueError(
            f"Anchor point has {len(target_frac)} coordinates but momentum "
            f"space has dimension {dim}."
        )

    precision = get_precision_config()
    k_frac = np.array(
        [[float(k.rep[j, 0]) for j in range(dim)] for k, _ in k_items],
        dtype=precision.np_float,
    )
    k_indices = np.array([idx for _, idx in k_items], dtype=np.int64)
    target_arr = np.asarray(target_frac, dtype=precision.np_float)
    diff = k_frac - target_arr
    diff = diff - np.round(diff)
    dist = np.linalg.norm(diff, axis=1)
    best_row = int(np.argmin(dist))
    anchor_idx = int(k_indices[best_row])

    H_anchor = h_k.data[anchor_idx]
    eigenvalues, eigenvectors = torch.linalg.eigh(H_anchor)

    mask = (eigenvalues - close_to).abs() <= tol
    selected = torch.nonzero(mask, as_tuple=False).flatten()
    n_selected = int(selected.numel())

    V = eigenvectors.index_select(-1, selected)  # (N, S)
    V_dag = V.conj().transpose(-2, -1)  # (S, N)

    projected = torch.einsum("ia,kab,bj->kij", V_dag, h_k.data, V)

    out_space = IndexSpace.linear(n_selected)
    return Tensor(data=projected, dims=(kspace, out_space, out_space))