YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)

Foundation Neural Quantum States (FNQS) pretrained on the disordered Heisenberg model on a square lattice.

For more details about the model, refer to the manuscript: https://arxiv.org/abs/2507.05073

Use the code below to get started with the model. It loads the pretrained weights for the revision selected by `L` and `p` and samples the model with NetKet.

from functools import partial
import jax
import jax.numpy as jnp
import netket as nk
import math
import flax
from flax.training import checkpoints
import numpy as np
from netket.operator.spin import sigmax, sigmaz, sigmay

# Make `flax.training.checkpoints` use its legacy (non-Orbax) checkpoint
# format; presumably the `spins` file restored with
# checkpoints.restore_checkpoint was written in that format — TODO confirm.
flax.config.update('flax_use_orbax_checkpointing', False)

p = 0.7 #* model parameter used only to pick the Hub revision; no external field enters the Hamiltonian built in this script (TODO confirm its meaning in the manuscript)
L = 6  # linear size of the periodic square lattice (L x L sites)
revision = f"L{L}_p{p}"  # Hub revision holding the weights, couplings and samples for this (L, p)

def edges_square_lattice(L):
    """Return the nearest-neighbour bonds of an L x L periodic square lattice.

    Sites are numbered row-major, ``site = row * L + col``. The result is an
    integer array of shape ``(2 * L * L, 2)``. It lists every site paired with
    its right neighbour (the column wraps around), followed by every site
    paired with its neighbour one row down (the row wraps around). The order
    matters: the coupling list is zipped against these rows.
    """
    n_sites = L * L
    site = np.arange(n_sites)
    row, col = np.divmod(site, L)

    right = row * L + (col + 1) % L  # same row, next column (periodic)
    down = (site + L) % n_sites      # same column, next row (periodic)

    horizontal = np.column_stack((site, right))
    vertical = np.column_stack((site, down))
    return np.vstack((horizontal, vertical))

def coupling_heis_random(random_J, edges):
    """Pair each bond with its exchange coupling.

    Args:
        random_J: iterable of couplings, one per bond, in the same order as
            ``edges``.
        edges: iterable of ``(i, j)`` site pairs (e.g. the rows returned by
            ``edges_square_lattice``).

    Returns:
        A list of ``(edge, J)`` tuples, in input order.

    Raises:
        ValueError: if the number of couplings differs from the number of
            edges. A bare ``zip`` would silently drop the surplus entries and
            produce a Hamiltonian that is missing bonds.
    """
    # Materialize so that lengths can be compared. Iterating an array yields
    # the same elements (rows / scalars) that zip would have produced.
    edge_list = list(edges)
    couplings = list(random_J)
    if len(couplings) != len(edge_list):
        raise ValueError(
            f"got {len(couplings)} couplings for {len(edge_list)} edges"
        )
    return list(zip(edge_list, couplings))

def si_sj(hi, i, j, txy=1.0):
    """Return the exchange term S_i . S_j between sites ``i`` and ``j``.

    The in-plane (x and y) part is scaled by ``txy``, giving
    ``0.25 * (txy * (X_i X_j + Y_i Y_j) + Z_i Z_j)`` on Hilbert space ``hi``.
    """
    xx = sigmax(hi, i) * sigmax(hi, j)
    yy = sigmay(hi, i) * sigmay(hi, j)
    zz = sigmaz(hi, i) * sigmaz(hi, j)
    # sigma* are Pauli operators and S = sigma / 2, so each two-site product
    # carries a factor (1/2)**2 = 0.25.
    in_plane = txy * (xx + yy)
    return 0.25 * (in_plane + zz)

def heisenberg_hamiltonian(edges_Js, hi, txy=1.0):
    """Build the bond-disordered Heisenberg Hamiltonian on Hilbert space ``hi``.

    Returns ``sum over (edge, J) of J * si_sj(hi, i, j, txy)``. ``edges_Js``
    yields ``(edge, J)`` pairs whose ``edge[0]`` and ``edge[1]`` are site
    indices. If ``edges_Js`` is empty, the result is the float ``0.0``.
    """
    total = 0.0
    for edge, coupling in edges_Js:
        site_a, site_b = edge[0], edge[1]
        total += coupling * si_sj(hi, site_a, site_b, txy)
    return total

# Download the pretrained FNQS for the selected revision from the Hugging Face
# Hub. trust_remote_code=True lets transformers run model code that ships with
# the Hub repository; only enable it for repositories you trust.
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/heisenberg_disorder_fnqs", 
                                   trust_remote_code=True, 
                                   revision=revision)

# Total number of scalar entries in the model's parameter pytree.
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

# Periodic L x L square lattice. Here it is used for the sampler's graph and
# for the number of sites; the Hamiltonian's bonds come from
# edges_square_lattice instead.
lattice = nk.graph.Hypercube(length=L, n_dim=2, pbc=True)
# Spin-1/2 Hilbert space restricted to the total_sz=0 sector.
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)

# Random Heisenberg Hamiltonian
from huggingface_hub import hf_hub_download
# Couplings stored in the same Hub revision as the model weights.
coups_path = hf_hub_download(repo_id="nqs-models/heisenberg_disorder_fnqs", filename="coups", revision=revision)
# Take row 0: presumably one coupling per bond, ordered like
# edges_square_lattice (the two are zipped together below).
# NOTE(review): assumes the file is 2-D; for a single-row file np.loadtxt
# returns 1-D and [0] would be a scalar — TODO confirm.
random_J = np.loadtxt(coups_path)[0]
edges = edges_square_lattice(L)
edges_Js = coupling_heis_random(random_J=random_J, edges=edges)

N_mc = 6000  # used as both the number of Markov chains and of samples

hamiltonian = heisenberg_hamiltonian(edges_Js, hilbert)
# Metropolis sampler that proposes exchanges of two spins within graph
# distance d_max=2 on the lattice. Exchanges preserve total S^z, which is
# consistent with the total_sz=0 Hilbert space. One sweep is lattice.n_nodes
# proposed moves.
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=N_mc,
                                        sweep_size=lattice.n_nodes)

key = jax.random.key(0)
key, subkey = jax.random.split(key, 2)
# Monte Carlo variational state. The pretrained network takes the couplings
# as an extra input, so they are bound into apply_fun with partial.
# n_samples == n_chains gives one sample per chain. n_discard_per_chain=0
# skips burn-in, presumably because the chains are then initialized from the
# stored `spins` configurations.
vstate = nk.vqs.MCState(sampler=sampler, 
                        apply_fun=partial(wf.__call__, coups=random_J), 
                        sampler_seed=subkey,
                        n_samples=N_mc, 
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=N_mc)

# Initialize the Markov chains from spin configurations stored in the same
# Hub revision (the variational state uses n_discard_per_chain=0).
path = hf_hub_download(repo_id="nqs-models/heisenberg_disorder_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, target=None)
# NOTE(review): presumably shaped (n_chains, n_sites) = (N_mc, L*L) — TODO confirm.
samples = jnp.array(samples, dtype='int8')
# NetKet's Metropolis sampler state keeps the chain configurations in a field
# named with the Greek letter sigma (U+03C3). The name is written as an
# escaped string key so that re-encoding this file cannot mangle it into a
# different identifier, which would make `replace` raise TypeError.
vstate.sampler_state = vstate.sampler_state.replace(**{"\u03c3": samples})

import time
# Sample the model
for _ in range(10):
    start = time.perf_counter()
    # Monte Carlo estimate of the energy from the variational state's samples.
    E = vstate.expect(hamiltonian)
    # Draw a fresh batch of samples for the next estimate.
    new_samples = vstate.sample()
    # JAX dispatches work asynchronously: wait for the results so the
    # measured time covers the computation itself, not just its dispatch.
    jax.block_until_ready((E.mean, new_samples))
    elapsed = time.perf_counter() - start

    # Energy per site. The first iteration typically also includes JIT
    # compilation time.
    print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", elapsed, flush=True)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support