YAML metadata warning: empty or missing YAML metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata).
Pretrained Foundation Neural Quantum States (FNQS) for the disordered Heisenberg model on a square lattice.
For more details about the model, refer to the manuscript: https://arxiv.org/abs/2507.05073
Use the code below to get started with the model. It samples the model with NetKet and prints the estimated energy per site.
from functools import partial
import jax
import jax.numpy as jnp
import netket as nk
import math
import flax
from flax.training import checkpoints
import numpy as np
from netket.operator.spin import sigmax, sigmaz, sigmay
# Use the legacy (non-Orbax) Flax checkpointing path: the "spins" file
# downloaded further down is read with flax.training.checkpoints.restore_checkpoint.
# NOTE(review): presumably that file is a legacy msgpack checkpoint -- confirm.
flax.config.update('flax_use_orbax_checkpointing', False)

# p and L together select the Hub revision (branch) f"L{L}_p{p}" from which the
# model weights, the couplings ("coups") and the initial spins ("spins") are
# downloaded. p is not an external field: the Hamiltonian built in this script
# contains only spin-spin couplings.
# NOTE(review): p presumably parametrizes the disorder of the couplings --
# confirm against the manuscript.
p = 0.7
L = 6  # linear size of the periodic L x L square lattice
revision = f"L{L}_p{p}"
def edges_square_lattice(L):
    """Return the nearest-neighbour bonds of a periodic L x L square lattice.

    Sites are numbered row-major (site = L * row + col). The result is an
    integer array of shape (2 * L * L, 2):

    * the first L * L rows are horizontal bonds (site, right neighbour),
      wrapping around within the same row;
    * the last L * L rows are vertical bonds (site, neighbour below),
      wrapping around from the last row to the first.
    """
    n_sites = L * L
    site = np.arange(n_sites)
    row, col = np.divmod(site, L)
    right = row * L + (col + 1) % L
    below = (site + L) % n_sites
    horizontal = np.column_stack((site, right))
    vertical = np.column_stack((site, below))
    return np.vstack((horizontal, vertical))
def coupling_heis_random(random_J, edges):
    """Pair each bond with its coupling constant.

    Args:
        random_J: couplings, one per bond, in the same order as ``edges``.
        edges: bonds, e.g. the rows returned by ``edges_square_lattice``.

    Returns:
        A list of ``(edge, J)`` tuples, in the order of ``edges``.

    Raises:
        ValueError: if the number of couplings differs from the number of
            bonds. A bare ``zip`` would silently truncate to the shorter
            input and yield a Hamiltonian with missing bonds.
    """
    # Materialize both inputs so iterators are still accepted and len() works;
    # iterating an ndarray yields the same rows/scalars that zip would see.
    edges = list(edges)
    random_J = list(random_J)
    if len(random_J) != len(edges):
        raise ValueError(
            f"expected one coupling per bond: got {len(random_J)} couplings "
            f"for {len(edges)} bonds"
        )
    return list(zip(edges, random_J))
def si_sj(hi, i, j, txy=1.0):
    """Return the spin-spin interaction between sites ``i`` and ``j`` on ``hi``.

    The operators are built from Pauli matrices, so the overall factor
    0.25 = (1/2)**2 converts sigma_a * sigma_b into products of spin-1/2
    operators S = sigma / 2. ``txy`` scales the transverse (XX + YY) part
    relative to the ZZ part; txy=1.0 gives the isotropic Heisenberg term.
    """
    xx = sigmax(hi, i) * sigmax(hi, j)
    yy = sigmay(hi, i) * sigmay(hi, j)
    zz = sigmaz(hi, i) * sigmaz(hi, j)
    return 0.25 * (txy * (xx + yy) + zz)
def heisenberg_hamiltonian(edges_Js, hi, txy=1.0):
    """Accumulate J * si_sj(i, j) over every (bond, coupling) pair.

    Args:
        edges_Js: iterable of ``(bond, J)`` pairs, where ``bond[0]`` and
            ``bond[1]`` are the two site indices (e.g. the output of
            ``coupling_heis_random``).
        hi: Hilbert space the operators act on.
        txy: transverse-coupling factor forwarded to ``si_sj``.

    Returns:
        The summed operator, or the float 0.0 if ``edges_Js`` is empty.
    """
    total = 0.0
    for bond, coupling in edges_Js:
        site_i, site_j = bond[0], bond[1]
        total += coupling * si_sj(hi, site_i, site_j, txy)
    return total
from transformers import FlaxAutoModel
# Download the pretrained network for the selected revision (branch).
# trust_remote_code=True lets transformers execute the model code that is
# shipped inside the Hub repository.
wf = FlaxAutoModel.from_pretrained(
    "nqs-models/heisenberg_disorder_fnqs",
    revision=revision,
    trust_remote_code=True,
)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

# Periodic L x L square lattice and the spin-1/2 Hilbert space on its sites,
# restricted to the total_sz=0 sector.
lattice = nk.graph.Hypercube(length=L, n_dim=2, pbc=True)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
# Random Heisenberg Hamiltonian
from huggingface_hub import hf_hub_download
# Download the coupling constants for this revision and take the first row.
# NOTE(review): each row presumably holds one disorder realization, with one J
# per bond in edges_square_lattice order -- confirm.
# ndmin=2 keeps a single-row file two-dimensional, so [0] still selects the
# whole row instead of collapsing to its first scalar.
coups_path = hf_hub_download(
    repo_id="nqs-models/heisenberg_disorder_fnqs",
    filename="coups",
    revision=revision,
)
random_J = np.loadtxt(coups_path, ndmin=2)[0]
edges = edges_square_lattice(L)
edges_Js = coupling_heis_random(random_J=random_J, edges=edges)

N_mc = 6000  # used below as both the number of chains and the number of samples
hamiltonian = heisenberg_hamiltonian(edges_Js, hilbert)
# Metropolis sampler with exchange moves between sites at graph distance
# <= d_max. Exchange moves swap two site values, so they preserve total Sz,
# consistent with the total_sz=0 Hilbert space.
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert,
                                        graph=lattice,
                                        d_max=2,
                                        n_chains=N_mc,
                                        sweep_size=lattice.n_nodes)
key = jax.random.key(0)
key, subkey = jax.random.split(key, 2)

# The pretrained network takes the couplings as an extra input, so they are
# bound into the apply function used by the variational state.
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=partial(wf.__call__, coups=random_J),
                        sampler_seed=subkey,
                        n_samples=N_mc,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=N_mc)

# Warm-start the Markov chains from stored spin configurations (presumably the
# reason n_discard_per_chain=0 above).
# NOTE(review): samples are assumed to have shape (N_mc, lattice.n_nodes) -- confirm.
path = hf_hub_download(repo_id="nqs-models/heisenberg_disorder_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, target=None)
samples = jnp.array(samples, dtype='int8')
# The chain configurations are stored in the sampler-state field named with
# the Greek letter sigma (σ).
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)
import time
# Sample the model: each iteration estimates <H> from the current samples,
# then draws a fresh batch, and reports the energy per site together with the
# time spent on both steps.
for _ in range(10):
    # perf_counter is monotonic and intended for measuring elapsed time;
    # time.time() follows the wall clock and can jump.
    start = time.perf_counter()
    E = vstate.expect(hamiltonian)
    vstate.sample()
    print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.perf_counter()-start, flush=True)
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
Ask for provider support