====================================================================================================
"""
TODOs
- flex attention (or other) kernel, preferably block-sparse for packing reasons
- if varlen, attn window warmup
- experiment with d_model/d_head/attn variants to fix low arithmetic intensity for attn
- async data loading (and comms in general)
- why ~50% hbm util? look for unnecessary syncs
- optimizers
- sharding pytrees across devices
- tune lr schedule/n_train_iters in current script (and general hparam sweeps across changes)
- figure out training with value embeddings (seems to increase step time disproportionately)
"""

import os
import sys
import glob
import time
import uuid
import dataclasses
import datetime

import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.lax import fori_loop, rsqrt, cond, scan
from jax.random import PRNGKey, split, normal
from jax.tree_util import (
    tree_map,
    tree_leaves,
    tree_map_with_path,
    tree_flatten,
    tree_unflatten,
    DictKey,
)
from jax.sharding import PartitionSpec as P, Mesh, NamedSharding, AxisType
from jax.nn import initializers, relu, log_softmax
from jax.nn import dot_product_attention

import einops
import pickle
import numpy as np
from functools import reduce, partial
import itertools
from dataclasses import dataclass, field
from typing import Any, Dict, List, Iterator, NamedTuple, Callable, Union
import collections.abc

PyTree = Any

# ======================== utils ========================


class Logger:
    def __init__(self):
        self.run_id = None
        self.logdir = None
        self.logfile = None
        self.is_master = jax.process_index() == 0
        if not self.is_master:
            return
        self.run_id = str(uuid.uuid4())
        self.logdir = f"logs/{self.run_id}/"
        os.makedirs(self.logdir, exist_ok=True)
        self.logfile = f"logs/{self.run_id}.txt"
        self.prev_metrics = None
        with open(self.logfile, "w") as f:
            with open(sys.argv[0]) as f2:
                code = f2.read()
            f.write("=" * 100 + "\n" + code + "\n" + "=" * 100 + "\n")

    def msg(self, msg: str):
        if not self.is_master:
            return
        print(msg)
        with open(self.logfile, "a") as f:
            f.write("[MESSAGE] " + str(msg) + "\n")

    def log(self, metrics: dict):
        if not self.is_master:
            return
        metrics, self.prev_metrics = self.prev_metrics, metrics
        if metrics is None:
            return
        metrics = " | ".join(
            list(itertools.starmap("{}: {}".format, metrics.items()))
        )
        print(metrics)
        with open(self.logfile, "a") as f:
            f.write("[METRICS (1 step stale)] " + str(metrics) + "\n")

    def flush(self):
        if not self.is_master:
            return
        metrics = self.prev_metrics
        self.prev_metrics = None
        if metrics is None:
            return
        metrics = " | ".join(
            list(itertools.starmap("{}: {}".format, metrics.items()))
        )
        print(metrics)
        with open(self.logfile, "a") as f:
            f.write("[METRICS (latest)] " + str(metrics) + "\n")

    def dump(self, step: int, params: PyTree, opt_state: PyTree, config):
        if not self.is_master:
            return
        params_host = jax.device_get(params)
        opt_state_host = jax.device_get(opt_state)
        state_to_save = {
            "step": step,
            "params": params_host,
            "opt_state": opt_state_host,
            "config": config,
        }
        save_path = f"{self.logdir}/state_step{step:06d}.pkl"
        with open(save_path, "wb") as f:
            pickle.dump(state_to_save, f)
        self.msg(f"Saved checkpoint to {save_path}")


def filter_pytree(pytree: PyTree, condition_map: Any) -> PyTree | None:
    if
condition_map is True: return pytree if not condition_map: return None if isinstance(pytree, collections.abc.Mapping) and isinstance( condition_map, collections.abc.Mapping ): filtered_dict = { k: subtree for k, sub_map in condition_map.items() if k in pytree and (subtree := filter_pytree(pytree[k], sub_map)) is not None } return pytree.__class__(filtered_dict) if isinstance(pytree, (list, tuple)) and isinstance(condition_map, (list, tuple)): filtered_list = [ subtree for item, sub_map in zip(pytree, condition_map) if (subtree := filter_pytree(item, sub_map)) is not None ] return type(pytree)(filtered_list) return None # ====================== training config ========================= @dataclass(kw_only=True, frozen=True) @jax.tree_util.register_static class Config: # mesh mesh_axis_names: tuple[str, ...] = ("dp",) mesh_shape: tuple[int, ...] = () # paths input_bin: str = "fineweb10B/fineweb_train_*.bin" input_val_bin: str = "fineweb10B/fineweb_val_*.bin" # iteration handling n_train_iters: int = 1675 n_warmup_iters: int = 0 f_warmdown_iters: float = 0.4 n_warmdown_iters: int = 0 val_loss_every: int = 125 val_tokens: int = 10485760 save_every: int = 0 # input sizes batch_size: int = 8 * 64 # batch size for min_sequence_length micro_batch_size: int = 16 min_sequence_length: int = 1024 # max_sequence_length: int = 2048 - 512 max_sequence_length: int = 2048 sequence_warmup_intervals: int = 1024 # init seed: int = 42 # eps adam_eps: float = 1e-10 # adam for embeddings adam_embed_base_lr: float = 0.6 adam_embed_beta1: float = 0.9 adam_embed_beta2: float = 0.95 # adam for lm head adam_lm_head_base_lr: float = 0.008 adam_lm_head_beta1: float = 0.9 adam_lm_head_beta2: float = 0.95 # muon for matrices muon_base_lr: float = 0.04 muon_momentum_warmup_steps: int = 500 muon_warmup_momentum_init: float = 0.85 muon_warmup_momentum_final: float = 0.95 muon_ns_iters: int = 5 muon_eps: float = 1e-7 # adam for non-matrices adam_nonmat_base_lr: float = 0.04 adam_nonmat_beta1: float = 0.9 adam_nonmat_beta2: float = 0.95 # arch n_layers: int = 12 d_model: int = 1024 n_heads: int = 4 d_head: int = 0 logit_softcap: float = 15.0 rope_base: float = 1024 vocab_size: int = 50304 dtype: str = "bfloat16" # sharding weight_sharding = None activation_sharding = ( None, "dp", ) # gradient accumulation axis, micro-batch axis, remaining are automatically None attention_sharding = ("dp", None, None, None) def __post_init__(self): object.__setattr__(self, "mesh_shape", (jax.device_count(),)) assert self.batch_size % self.micro_batch_size == 0 object.__setattr__( self, "n_warmdown_iters", int(self.n_train_iters * self.f_warmdown_iters)) assert self.d_model % self.n_heads == 0 object.__setattr__(self, "d_head", self.d_model // self.n_heads) assert self.n_layers % 2 == 0 def get_mesh(config: Config): mesh = jax.make_mesh(config.mesh_shape, config.mesh_axis_names) return mesh # ====================== optimizers ================== class Optimizer(NamedTuple): init: Callable update: Callable def get_lr(it, n_warmup_iters, n_warmdown_iters, n_train_iters): warmup_lr = (it + 1) / n_warmup_iters constant_lr = 1.0 warmdown_lr = (n_train_iters - it) / n_warmdown_iters * (1.0 - 0.1) + 0.1 lr = jnp.where( it < n_warmup_iters, warmup_lr, jnp.where(it < n_train_iters - n_warmdown_iters, constant_lr, warmdown_lr), ) return lr def adam( base_lr: float, b1: float, b2: float, n_warmup_iters: int, n_warmdown_iters: int, n_train_iters: int, adam_eps: float, ): def init(params): m = tree_map(jnp.zeros_like, params) v = tree_map(jnp.zeros_like, 
params) step = jnp.array(0, dtype=jnp.int32) return {"m": m, "v": v, "step": step} def update(grads, params, state): step = state["step"] + 1 lr = base_lr * get_lr( state["step"], n_warmup_iters, n_warmdown_iters, n_train_iters ) m = tree_map(lambda m, g: (b1 * m + (1 - b1) * g).astype(m.dtype), state["m"], grads) v = tree_map(lambda v, g: (b2 * v + (1 - b2) * g**2).astype(v.dtype), state["v"], grads) m_hat = tree_map(lambda m: (m / (1 - b1**step)).astype(m.dtype), m) v_hat = tree_map(lambda v: (v / (1 - b2**step)).astype(v.dtype), v) updates = tree_map(lambda m, v: lr * m / (jnp.sqrt(v) + adam_eps), m_hat, v_hat) new_params = tree_map(lambda p, u: p - u.astype(p.dtype), params, updates) new_state = {"m": m, "v": v, "step": step} return new_params, new_state return Optimizer(init, update) def zeropower_via_newtonschulz5(G, steps, eps): assert len(G.shape) == 2 transpose = G.shape[0] > G.shape[1] def _update_loop(X): a, b, c = (3.4445, -4.7750, 2.0315) for i in range(steps): A = X @ X.T B = b * A + c * (A @ A) X = a * X + B @ X return X def tall_case(g): X = g.T.astype(jnp.bfloat16) X /= jnp.linalg.norm(X) + eps X_final = _update_loop(X) return X_final.T.astype(g.dtype) def wide_case(g): X = g.astype(jnp.bfloat16) X /= jnp.linalg.norm(X) + eps X_final = _update_loop(X) return X_final.astype(g.dtype) return cond(transpose, tall_case, wide_case, G) def muon( base_lr: float, momentum_warmup_steps: int, warmup_momentum_init: float, warmup_momentum_final: float, n_warmup_iters: int, n_warmdown_iters: int, n_train_iters: int, ns_iters: int, eps: float, ): def init(params): m = tree_map(jnp.zeros_like, params) step = jnp.array(0, dtype=jnp.int32) return {"m": m, "step": step} def update(grads, params, state): step = state["step"] lr = base_lr * get_lr(step, n_warmup_iters, n_warmdown_iters, n_train_iters) frac = jnp.minimum(step / momentum_warmup_steps, 1.0) momentum = (warmup_momentum_init + frac * ( warmup_momentum_final - warmup_momentum_init )) new_m = tree_map(lambda m, g: (m + (1 - momentum).astype(m.dtype) * (g - m)).astype(m.dtype), state["m"], grads) def _update_leaf(g, p, m): g_nesterov = g + momentum.astype(m.dtype) * (m - g) update = ( lr.astype(p.dtype) * zeropower_via_newtonschulz5(g_nesterov, ns_iters, eps).astype(p.dtype) * jnp.sqrt(jnp.maximum(1.0, g.shape[0] / g.shape[1])).astype(p.dtype) ) return p - update new_params = tree_map(_update_leaf, grads, params, new_m) new_state = {"m": new_m, "step": step + 1} return new_params, new_state return Optimizer(init, update) def multi_optimizer(optimizer_map: Any, **optimizers: Optimizer): optimizer_names = list(optimizers.keys()) def init(params): states = {} for name, opt in optimizers.items(): is_relevant_map = tree_map(lambda label: label == name, optimizer_map) params_subset = filter_pytree(params, is_relevant_map) states[name] = opt.init(params_subset) return states def update(grads, params, states): leaves, treedef = tree_flatten(params) map_leaves, _ = tree_flatten(optimizer_map) new_leaves_list = list(leaves) new_states = {} for name, opt in optimizers.items(): is_relevant_map = tree_map(lambda label: label == name, optimizer_map) grads_subset = filter_pytree(grads, is_relevant_map) params_subset = filter_pytree(params, is_relevant_map) if not grads_subset or not tree_leaves(grads_subset): new_states[name] = states[name] continue new_params_subset, new_states[name] = opt.update( grads_subset, params_subset, states[name] ) subset_leaves, _ = tree_flatten(new_params_subset) original_indices = [ i for i, label in 
enumerate(map_leaves) if label == name ] assert len(subset_leaves) == len( original_indices ), f"Mismatch for optimizer {name}" for idx, leaf_val in zip(original_indices, subset_leaves): new_leaves_list[idx] = leaf_val new_params = tree_unflatten(treedef, new_leaves_list) return new_params, new_states return Optimizer(init, update) def create_optimizer_map(params): def get_label(path, leaf): is_adam_embed = any(isinstance(k, DictKey) and k.key == "wte" for k in path) is_adam_lm_head = any( isinstance(k, DictKey) and k.key == "lm_head" for k in path ) is_muon = ( any(isinstance(k, DictKey) and k.key == "h" for k in path) and leaf.ndim == 2 ) is_adam_nonmat = any( isinstance(k, DictKey) and k.key == "skip_weights" for k in path ) or ( leaf.ndim < 2 and any(isinstance(k, DictKey) and k.key == "h" for k in path) ) assert ( int(is_adam_embed) + int(is_adam_lm_head) + int(is_muon) + int(is_adam_nonmat) == 1 ) if is_adam_embed: return "adam_embed" elif is_adam_lm_head: return "adam_lm_head" elif is_muon: return "muon" else: return "adam_nonmat" return tree_map_with_path(get_label, params) # ======================== dataset ============================= def _get_shape_for_step(step: int, config: Config): available_seq_lens = np.arange( config.min_sequence_length, config.max_sequence_length + 1, config.sequence_warmup_intervals, ) if config.max_sequence_length not in available_seq_lens: available_seq_lens = np.append(available_seq_lens, config.max_sequence_length) n_lens = len(available_seq_lens) idx = int((step / config.n_train_iters) * n_lens) current_seq_len = available_seq_lens[idx] # progress = step / max(1, config.n_train_iters - 1) # target_len = config.min_sequence_length + progress * ( # config.max_sequence_length - config.min_sequence_length # ) # current_seq_len = available_seq_lens[ # np.argmin(np.abs(available_seq_lens - target_len)) # ] total_tokens = config.batch_size * config.min_sequence_length current_B = total_tokens // current_seq_len current_B = (current_B // config.micro_batch_size) * config.micro_batch_size current_B = max(current_B, config.micro_batch_size) current_n_grad_acc = current_B // config.micro_batch_size return int(current_seq_len), int(current_B), int(current_n_grad_acc) def load_dataset( config: Config, logger: Logger, mesh: Mesh, is_training: bool ) -> List[tuple[jax.Array, jax.Array]]: def _load_data_shard(filename): with open(filename, "rb") as f: header = np.frombuffer(f.read(256 * 4), dtype=np.int32) assert header[0] == 20240520, f"Magic number mismatch in {filename}" assert header[1] == 1, f"Unsupported version in {filename}" ntok = header[2] tokens = np.frombuffer(f.read(), dtype=np.uint16) assert len(tokens) == ntok, f"Token count mismatch in {filename}" return tokens process_rank = jax.process_index() num_processes = jax.process_count() files = sorted(glob.glob(config.input_bin)) if not files: raise RuntimeError(f"No files found for pattern {config.input_bin}") logger.msg( f"Process {process_rank}/{num_processes} starting data pre-loading into RAM..." ) all_tokens_list = [_load_data_shard(f) for f in files] all_tokens = np.concatenate(all_tokens_list) logger.msg( f"Process {process_rank}/{num_processes} finished loading {all_tokens.nbytes / 1e9:.2f} GB of tokens." 
) shape_schedule = [] if is_training: for step in range(config.n_train_iters): seq_len, B, _ = _get_shape_for_step(step, config) shape_schedule.append({"seq_len": seq_len, "B": B}) num_global_batches = config.n_train_iters else: seq_len = config.max_sequence_length total_tokens = config.batch_size * config.min_sequence_length batch_size = ( total_tokens // seq_len // config.micro_batch_size ) * config.micro_batch_size batch_size = max(batch_size, config.micro_batch_size) total_tokens_per_batch = batch_size * seq_len num_global_batches = config.val_tokens // total_tokens_per_batch for _ in range(num_global_batches): shape_schedule.append({"seq_len": seq_len, "B": batch_size}) precomputed_batches = [] token_cursor = 0 activation_sharding = NamedSharding(mesh, P(*config.activation_sharding)) for global_step_idx in range(num_global_batches): shape_info = shape_schedule[global_step_idx] tokens_for_this_batch = shape_info["B"] * shape_info["seq_len"] if global_step_idx % num_processes == process_rank: seq_len = shape_info["seq_len"] batch_size = shape_info["B"] n_grad_acc_steps = batch_size // config.micro_batch_size start_idx = token_cursor end_idx = start_idx + tokens_for_this_batch + 1 if end_idx > len(all_tokens): if process_rank == 0: logger.msg("Cycling dataset...") token_cursor = 0 start_idx = 0 end_idx = tokens_for_this_batch + 1 if end_idx > len(all_tokens): raise RuntimeError(f"Not enough tokens ({len(all_tokens)}) to form even one batch of size {tokens_for_this_batch+1}.") buf = all_tokens[start_idx:end_idx] x = np.array(buf[:-1], dtype=np.int32).reshape(batch_size, seq_len) y = np.array(buf[1:], dtype=np.int32).reshape(batch_size, seq_len) batched_x = einops.rearrange( x, "(a b) ... -> a b ...", a=n_grad_acc_steps ) batched_y = einops.rearrange( y, "(a b) ... -> a b ...", a=n_grad_acc_steps ) batched_x = jax.device_put(batched_x, activation_sharding) batched_y = jax.device_put(batched_y, activation_sharding) precomputed_batches.append((batched_x, batched_y)) token_cursor += tokens_for_this_batch logger.msg( f"Process {process_rank}/{num_processes} pre-computed {len(precomputed_batches)} batches." ) if num_global_batches > 0 and not precomputed_batches: raise RuntimeError(f"Process {process_rank} could not create any batches. 
" "Check data size, batch configuration, and number of processes.") return precomputed_batches # ======================== inits ============================= def precompute_rope(config: Config, mesh: Mesh) -> PyTree: weight_sharding = NamedSharding(mesh, P(config.weight_sharding)) dim = config.d_head seq_len = config.max_sequence_length inv_freq = 1.0 / ( config.rope_base ** (jnp.arange(0, dim // 4, dtype=jnp.float32) / (dim // 4)) ) inv_freq = jnp.concatenate([inv_freq, jnp.zeros_like(inv_freq)]) t = jnp.arange(seq_len) freqs = jnp.outer(t, inv_freq) cos = jnp.cos(freqs).astype(config.dtype) sin = jnp.sin(freqs).astype(config.dtype) precomputed_params = {} precomputed_params["cos"] = jax.device_put(cos[:, None, :], weight_sharding) precomputed_params["sin"] = jax.device_put(sin[:, None, :], weight_sharding) return precomputed_params def init_params(config: Config, mesh: Mesh) -> PyTree: weight_sharding = NamedSharding(mesh, P(config.weight_sharding)) def sharded_zeros(shape): arr = jnp.zeros(shape, dtype=config.dtype) return jax.device_put(arr, weight_sharding) def sharded_ones(shape): arr = jnp.ones(shape, dtype=config.dtype) return jax.device_put(arr, weight_sharding) def sharded_normal(key, shape, std): arr = jax.random.normal(key, shape, dtype=config.dtype) * std return jax.device_put(arr, weight_sharding) def sharded_uniform(key, shape, bound): arr = jax.random.uniform( key, shape, dtype=config.dtype, minval=-bound, maxval=bound ) return jax.device_put(arr, weight_sharding) root_key = jax.random.key(seed=config.seed) key = map(partial(jax.random.fold_in, root_key), itertools.count()) params = dict() params["wte"] = sharded_normal(next(key), (config.vocab_size, config.d_model), 1.0) params["h"] = [] params["skip_weights"] = sharded_ones(config.n_layers // 2) params["lm_head"] = sharded_zeros((config.d_model, config.vocab_size)) for i in range(config.n_layers): block_params = dict() block_params["attn"] = dict() block_params["attn"]["c_qkv"] = sharded_uniform( next(key), (3 * config.d_model, config.d_model), (0.75 / config.d_model) ** 0.5, ) block_params["attn"]["c_proj"] = sharded_zeros((config.d_model, config.d_model)) block_params["attn"]["lamb"] = jnp.array(0.5, dtype=config.dtype) block_params["attn"]["scale"] = jnp.array(0.12, dtype=config.dtype) block_params["mlp"] = dict() block_params["mlp"]["c_fc"] = sharded_uniform( next(key), (config.d_model, 4 * config.d_model), (0.75 / config.d_model) ** 0.5, ) block_params["mlp"]["c_proj"] = sharded_zeros( (4 * config.d_model, config.d_model) ) lambdas_arr = jnp.array([1.0, 0.0], dtype=config.dtype) block_params["lambdas"] = jax.device_put(lambdas_arr, weight_sharding) params["h"].append(block_params) precomputed_params = precompute_rope(config, mesh) return params, precomputed_params def init_optimizer(config: Config, params: PyTree, mesh: Mesh): optimizer_map = create_optimizer_map(params) adam_embed = adam( config.adam_embed_base_lr, config.adam_embed_beta1, config.adam_embed_beta2, config.n_warmup_iters, config.n_warmdown_iters, config.n_train_iters, config.adam_eps, ) adam_lm_head = adam( config.adam_lm_head_base_lr, config.adam_lm_head_beta1, config.adam_lm_head_beta2, config.n_warmup_iters, config.n_warmdown_iters, config.n_train_iters, config.adam_eps, ) adam_nonmat = adam( config.adam_nonmat_base_lr, config.adam_nonmat_beta1, config.adam_nonmat_beta2, config.n_warmup_iters, config.n_warmdown_iters, config.n_train_iters, config.adam_eps, ) muon_opt = muon( config.muon_base_lr, config.muon_momentum_warmup_steps, 
config.muon_warmup_momentum_init, config.muon_warmup_momentum_final, config.n_warmup_iters, config.n_warmdown_iters, config.n_train_iters, config.muon_ns_iters, config.muon_eps, ) optimizer = multi_optimizer( optimizer_map, adam_embed=adam_embed, adam_lm_head=adam_lm_head, muon=muon_opt, adam_nonmat=adam_nonmat, ) opt_state = optimizer.init(params) return optimizer, opt_state # ======================= model and loss ======================= def rms_norm(x, config): return x * rsqrt( jnp.mean(jnp.square(x), axis=-1, keepdims=True) + jnp.finfo(x.dtype).eps ) def apply_rotary_emb(x, cos, sin): d = x.shape[-1] // 2 x1, x2 = x[..., :d], x[..., d:] y1 = x1 * cos - x2 * sin y2 = x1 * sin + x2 * cos return jnp.concatenate([y1, y2], axis=-1).astype(x.dtype) def linear(x, weight): return jnp.einsum("...i,io->...o", x, weight.astype(x.dtype)) def attention_forward(params, x, v1, cos, sin, config): B, T, C = x.shape params_qkv = einops.rearrange( params["c_qkv"], "(three h d) c -> three h d c", d=config.d_head, three=3 ) q, k, v = einops.einsum(x, params_qkv, "b t c, three h d c -> three b t h d") if v1 is None: v1 = v v = (1 - params["lamb"]) * v + params["lamb"] * v1.reshape(v.shape) q = apply_rotary_emb(rms_norm(q, config), cos, sin) k = apply_rotary_emb(rms_norm(k, config), cos, sin) y = dot_product_attention(q, k, v, scale=params["scale"], is_causal=True).reshape( B, T, C ) y = linear(y, params["c_proj"]) return y, v1 def mlp_forward(params, x): x = linear(x, params["c_fc"]) x = relu(x) ** 2 x = linear(x, params["c_proj"]) return x def block_forward(params, x, v1, x0, cos, sin, config): x = params["lambdas"][0] * x + params["lambdas"][1] * x0 x1, v1 = attention_forward( params["attn"], rms_norm(x, config), v1, cos, sin, config ) x = x + x1 x = x + mlp_forward(params["mlp"], rms_norm(x, config)) return x, v1 def gpt_forward(params, idx, precomputed_params, config): _, T = idx.shape x = params["wte"][idx] x = rms_norm(x, config) x0 = x v1 = None skip_connections = [] n_encoder_layers = config.n_layers // 2 n_decoder_layers = config.n_layers - n_encoder_layers cos = precomputed_params["cos"][:T, :, :] sin = precomputed_params["sin"][:T, :, :] for i in range(n_encoder_layers): x, v1 = block_forward(params["h"][i], x, v1, x0, cos, sin, config) skip_connections.append(x) for i in range(n_decoder_layers): x = x + params["skip_weights"][i] * skip_connections.pop() x, v1 = block_forward( params["h"][n_encoder_layers + i], x, v1, x0, cos, sin, config ) x = rms_norm(x, config) logits = linear(x, params["lm_head"]) logits = (2.0 * config.logit_softcap) * jax.nn.sigmoid( logits / (config.logit_softcap / 2.0) ) return logits.astype(jnp.float32) def loss_fn(params, batch, precomputed_params, config): idx, labels = batch logits = gpt_forward(params, idx, precomputed_params, config) axis = logits.ndim - 1 label_logits = jnp.take_along_axis( logits, jnp.expand_dims(labels, axis), axis=axis ).take(0, axis=axis) log_normalizers = jax.nn.logsumexp(logits, axis=axis) return jnp.mean(log_normalizers - label_logits) # ======================== training ============================ def train_step( config: Config, params: PyTree, precomputed_params: PyTree, opt_state: PyTree, optimizer: Optimizer, batched_x: jax.Array, batched_y: jax.Array, ) -> tuple[PyTree, PyTree, dict]: n_grad_acc_steps = batched_x.shape[0] def loss_and_grad_fn(p, micro_batch): return value_and_grad(loss_fn)(p, micro_batch, precomputed_params, config) def micro_step(carry, micro_batch): accum_grads, total_loss = carry loss, grads = 
loss_and_grad_fn(params, micro_batch) new_accum_grads = tree_map(jnp.add, accum_grads, grads) return (new_accum_grads, total_loss + loss), None zero_grads = tree_map(jnp.zeros_like, params) init_carry = (zero_grads, 0.0) (final_grads_accum, total_loss), _ = scan( micro_step, init_carry, (batched_x, batched_y) ) avg_loss = total_loss / n_grad_acc_steps final_grads = tree_map(lambda g: (g / n_grad_acc_steps).astype(g.dtype), final_grads_accum) new_params, new_opt_state = optimizer.update(final_grads, params, opt_state) return new_params, new_opt_state, {"loss": avg_loss} def eval_step( params: PyTree, batched_x: jax.Array, batched_y: jax.Array, precomputed_params: PyTree, config: Config, ) -> jax.Array: n_grad_acc_steps = batched_x.shape[0] def loss_loop_body(i, accumulated_loss): micro_batch = (batched_x[i], batched_y[i]) loss = loss_fn(params, micro_batch, precomputed_params, config) return accumulated_loss + loss total_loss = fori_loop(0, n_grad_acc_steps, loss_loop_body, 0.0) avg_loss = total_loss / n_grad_acc_steps return avg_loss def run_evaluation( step: int, config: Config, params: PyTree, val_loader: Iterator, precomputed_params: PyTree, mesh: Mesh, logger: Logger, compiled_eval_fn: Callable, ): logger.msg(f"Running validation for step {step}...") val_loss_accum = 0.0 val_steps = 0 for batched_x, batched_y in val_loader: loss = compiled_eval_fn(params, batched_x, batched_y, precomputed_params) val_loss_accum += loss val_steps += 1 if val_steps == 0: if step == config.val_loss_every or step >= config.n_train_iters -1: logger.msg( "Warning: Validation loader was empty, no validation was run." ) return final_val_loss = val_loss_accum / val_steps logger.log({"step": step, "val_loss": final_val_loss}) logger.msg(f"Validation finished for step {step}.") def train_loop(config: Config): logger = Logger() mesh = get_mesh(config) with mesh: params, precomputed_params = init_params(config, mesh) optimizer, opt_state = init_optimizer(config, params, mesh) jitted_train_step = jit(train_step, static_argnames=("config", "optimizer"), donate_argnums=(1, 3)) jitted_eval_step = jit(eval_step, static_argnames=("config",)) logger.msg("Determining all unique training shapes...") train_shapes = { _get_shape_for_step(s, config) for s in range(config.n_train_iters) } val_config = dataclasses.replace(config, input_bin=config.input_val_bin) val_seq_len = val_config.max_sequence_length total_tokens = val_config.batch_size * val_config.min_sequence_length val_B = ( total_tokens // val_seq_len // val_config.micro_batch_size ) * val_config.micro_batch_size val_B = max(val_B, val_config.micro_batch_size) val_n_grad_acc = val_B // val_config.micro_batch_size val_shape = (val_seq_len, val_B, val_n_grad_acc) train_shapes.add(val_shape) logger.msg("Starting Ahead-of-Time (AOT) compilation for all shapes...") compiled_train_steps = {} compiled_eval_fn = None activation_sharding = NamedSharding(mesh, P(*config.activation_sharding)) for seq_len, batch_size, n_grad_acc_steps in sorted(list(train_shapes)): shape_key = (seq_len, batch_size, n_grad_acc_steps) logger.msg( f"AOT compiling for seq_len={seq_len}, B={batch_size}, grad_acc={n_grad_acc_steps}..." 
            )
            dummy_x_shape = (n_grad_acc_steps, config.micro_batch_size, seq_len)
            dummy_x = jnp.zeros(dummy_x_shape, dtype=jnp.int32)
            dummy_y = jnp.zeros_like(dummy_x)
            dummy_x = jax.device_put(dummy_x, activation_sharding)
            dummy_y = jax.device_put(dummy_y, activation_sharding)
            compiled_fn = jitted_train_step.lower(
                config, params, precomputed_params, opt_state, optimizer, dummy_x, dummy_y
            ).compile()
            compiled_train_steps[shape_key] = compiled_fn
            if shape_key == val_shape:
                compiled_eval_fn = jitted_eval_step.lower(
                    params, dummy_x, dummy_y, precomputed_params, config
                ).compile()
        logger.msg("AOT compilation finished for all function variants.")

        logger.msg("Pre-computing and loading all training batches...")
        train_batches = load_dataset(config, logger, mesh, is_training=True)
        train_loader = iter(train_batches)
        logger.msg(f"Loaded {len(train_batches)} training batches for this process.")
        logger.msg("Pre-computing and loading all validation batches...")
        val_batches = load_dataset(val_config, logger, mesh, is_training=False)
        logger.msg(f"Loaded {len(val_batches)} validation batches for this process.")

        logger.msg("Starting training...")
        for step in range(config.n_train_iters):
            batched_x, batched_y = next(train_loader)
            n_grad_acc, _, seq_len = batched_x.shape
            batch_size = n_grad_acc * config.micro_batch_size
            current_shape_key = (seq_len, batch_size, n_grad_acc)
            aot_train_fn = compiled_train_steps[current_shape_key]
            params, opt_state, metrics = aot_train_fn(
                params,
                precomputed_params,
                opt_state,
                batched_x,
                batched_y,
            )
            if step % 10 == 9:
                logger.log({"step": step, "time": datetime.datetime.now()} | metrics)
            if step > 0 and (step % config.val_loss_every == 0):
                run_evaluation(
                    step, config, params, iter(val_batches), precomputed_params, mesh, logger, compiled_eval_fn
                )
            if config.save_every > 0 and step > 0 and (step % config.save_every == 0):
                logger.dump(step, params, opt_state, config)

        logger.flush()
        logger.msg("Final validation")
        run_evaluation(
            step, config, params, iter(val_batches), precomputed_params, mesh, logger, compiled_eval_fn
        )
        logger.flush()
        logger.msg("Training finished.")
        logger.dump(step, params, opt_state, config)


if __name__ == "__main__":
    config = Config()
    train_loop(config)
====================================================================================================
[MESSAGE] Determining all unique training shapes...
[MESSAGE] Starting Ahead-of-Time (AOT) compilation for all shapes...
[MESSAGE] AOT compiling for seq_len=1024, B=512, grad_acc=32...
[MESSAGE] AOT compiling for seq_len=2048, B=256, grad_acc=16...
[MESSAGE] AOT compilation finished for all function variants.
[MESSAGE] Pre-computing and loading all training batches...
[MESSAGE] Process 0/1 starting data pre-loading into RAM...
[MESSAGE] Process 0/1 finished loading 6.00 GB of tokens.
[MESSAGE] Process 0/1 pre-computed 1675 batches.
[MESSAGE] Loaded 1675 training batches for this process.
[MESSAGE] Pre-computing and loading all validation batches...
[MESSAGE] Process 0/1 starting data pre-loading into RAM...
[MESSAGE] Process 0/1 finished loading 0.20 GB of tokens.
[MESSAGE] Process 0/1 pre-computed 20 batches.
[MESSAGE] Loaded 20 validation batches for this process.
[MESSAGE] Starting training...
[METRICS (1 step stale)] step: 9 | time: 2025-08-20 13:50:50.485907 | loss: 6.471743583679199
[METRICS (1 step stale)] step: 19 | time: 2025-08-20 13:50:50.516435 | loss: 6.047224998474121
[METRICS (1 step stale)] step: 29 | time: 2025-08-20 13:50:53.609852 | loss: 5.782470703125
[METRICS (1 step stale)] step: 39 | time: 2025-08-20 13:50:56.683170 | loss: 5.512787342071533
[METRICS (1 step stale)] step: 49 | time: 2025-08-20 13:50:59.756620 | loss: 5.361858367919922
[METRICS (1 step stale)] step: 59 | time: 2025-08-20 13:51:02.829865 | loss: 5.1754302978515625
[METRICS (1 step stale)] step: 69 | time: 2025-08-20 13:51:05.902741 | loss: 5.034331321716309
[METRICS (1 step stale)] step: 79 | time: 2025-08-20 13:51:08.976037 | loss: 4.853498935699463
[METRICS (1 step stale)] step: 89 | time: 2025-08-20 13:51:12.048553 | loss: 4.690052509307861
[METRICS (1 step stale)] step: 99 | time: 2025-08-20 13:51:15.121923 | loss: 4.592535495758057
[METRICS (1 step stale)] step: 109 | time: 2025-08-20 13:51:18.194632 | loss: 4.586787223815918
[MESSAGE] Running validation for step 125...
[METRICS (1 step stale)] step: 119 | time: 2025-08-20 13:51:21.267343 | loss: 4.439126491546631
[MESSAGE] Validation finished for step 125.
[METRICS (1 step stale)] step: 125 | val_loss: 4.378868579864502
[METRICS (1 step stale)] step: 129 | time: 2025-08-20 13:51:29.866469 | loss: 4.3136305809021
[METRICS (1 step stale)] step: 139 | time: 2025-08-20 13:51:31.083749 | loss: 4.403034210205078
[METRICS (1 step stale)] step: 149 | time: 2025-08-20 13:51:32.312318 | loss: 4.317258358001709
[METRICS (1 step stale)] step: 159 | time: 2025-08-20 13:51:35.383015 | loss: 4.266417503356934
[METRICS (1 step stale)] step: 169 | time: 2025-08-20 13:51:38.454729 | loss: 4.174162864685059
[METRICS (1 step stale)] step: 179 | time: 2025-08-20 13:51:41.524670 | loss: 4.178581714630127
[METRICS (1 step stale)] step: 189 | time: 2025-08-20 13:51:44.595092 | loss: 4.1373796463012695
[METRICS (1 step stale)] step: 199 | time: 2025-08-20 13:51:47.665457 | loss: 4.1474151611328125
[METRICS (1 step stale)] step: 209 | time: 2025-08-20 13:51:50.735650 | loss: 4.031351566314697
[METRICS (1 step stale)] step: 219 | time: 2025-08-20 13:51:53.806046 | loss: 4.0923895835876465
[METRICS (1 step stale)] step: 229 | time: 2025-08-20 13:51:56.876283 | loss: 4.0985612869262695
[METRICS (1 step stale)] step: 239 | time: 2025-08-20 13:51:59.946486 | loss: 4.002068996429443
[MESSAGE] Running validation for step 250...
[METRICS (1 step stale)] step: 249 | time: 2025-08-20 13:52:03.016231 | loss: 4.068719387054443
[MESSAGE] Validation finished for step 250.
[METRICS (1 step stale)] step: 250 | val_loss: 3.9816060066223145
[METRICS (1 step stale)] step: 259 | time: 2025-08-20 13:52:10.259963 | loss: 3.9959166049957275
[METRICS (1 step stale)] step: 269 | time: 2025-08-20 13:52:11.290507 | loss: 4.009436130523682
[METRICS (1 step stale)] step: 279 | time: 2025-08-20 13:52:14.054369 | loss: 3.9717233180999756
[METRICS (1 step stale)] step: 289 | time: 2025-08-20 13:52:17.124270 | loss: 3.9549062252044678
[METRICS (1 step stale)] step: 299 | time: 2025-08-20 13:52:20.193791 | loss: 3.95847749710083
[METRICS (1 step stale)] step: 309 | time: 2025-08-20 13:52:23.263515 | loss: 3.9942121505737305
[METRICS (1 step stale)] step: 319 | time: 2025-08-20 13:52:26.333754 | loss: 3.99407958984375
[METRICS (1 step stale)] step: 329 | time: 2025-08-20 13:52:29.403493 | loss: 3.952791213989258
[METRICS (1 step stale)] step: 339 | time: 2025-08-20 13:52:32.473394 | loss: 3.8804521560668945
[METRICS (1 step stale)] step: 349 | time: 2025-08-20 13:52:35.543009 | loss: 3.8041751384735107
[METRICS (1 step stale)] step: 359 | time: 2025-08-20 13:52:38.613007 | loss: 3.7829675674438477
[MESSAGE] Running validation for step 375...
[METRICS (1 step stale)] step: 369 | time: 2025-08-20 13:52:41.681485 | loss: 3.8893344402313232
[MESSAGE] Validation finished for step 375.
[METRICS (1 step stale)] step: 375 | val_loss: 3.8263344764709473
[METRICS (1 step stale)] step: 379 | time: 2025-08-20 13:52:50.273122 | loss: 3.8724312782287598
[METRICS (1 step stale)] step: 389 | time: 2025-08-20 13:52:51.489860 | loss: 3.8755533695220947
[METRICS (1 step stale)] step: 399 | time: 2025-08-20 13:52:52.716833 | loss: 3.859685182571411
[METRICS (1 step stale)] step: 409 | time: 2025-08-20 13:52:55.786368 | loss: 3.8463704586029053
[METRICS (1 step stale)] step: 419 | time: 2025-08-20 13:52:58.854976 | loss: 3.8484466075897217
[METRICS (1 step stale)] step: 429 | time: 2025-08-20 13:53:01.924491 | loss: 3.7746026515960693
[METRICS (1 step stale)] step: 439 | time: 2025-08-20 13:53:04.994403 | loss: 3.8191089630126953
[METRICS (1 step stale)] step: 449 | time: 2025-08-20 13:53:08.063666 | loss: 3.818345546722412
[METRICS (1 step stale)] step: 459 | time: 2025-08-20 13:53:11.133400 | loss: 3.7937140464782715
[METRICS (1 step stale)] step: 469 | time: 2025-08-20 13:53:14.203035 | loss: 3.801438570022583
[METRICS (1 step stale)] step: 479 | time: 2025-08-20 13:53:17.272227 | loss: 3.7843170166015625
[METRICS (1 step stale)] step: 489 | time: 2025-08-20 13:53:20.341485 | loss: 3.7632782459259033
[MESSAGE] Running validation for step 500...
[METRICS (1 step stale)] step: 499 | time: 2025-08-20 13:53:23.410766 | loss: 3.790884256362915
[MESSAGE] Validation finished for step 500.
[METRICS (1 step stale)] step: 500 | val_loss: 3.7325398921966553
[METRICS (1 step stale)] step: 509 | time: 2025-08-20 13:53:30.653156 | loss: 3.772469997406006
[METRICS (1 step stale)] step: 519 | time: 2025-08-20 13:53:31.683826 | loss: 3.7651329040527344
[METRICS (1 step stale)] step: 529 | time: 2025-08-20 13:53:34.445748 | loss: 3.792227268218994
[METRICS (1 step stale)] step: 539 | time: 2025-08-20 13:53:37.515855 | loss: 3.8172860145568848
[METRICS (1 step stale)] step: 549 | time: 2025-08-20 13:53:40.585061 | loss: 3.6823465824127197
[METRICS (1 step stale)] step: 559 | time: 2025-08-20 13:53:43.654424 | loss: 3.7520127296447754
[METRICS (1 step stale)] step: 569 | time: 2025-08-20 13:53:46.723531 | loss: 3.747046709060669
[METRICS (1 step stale)] step: 579 | time: 2025-08-20 13:53:49.792985 | loss: 3.7283899784088135
[METRICS (1 step stale)] step: 589 | time: 2025-08-20 13:53:52.861550 | loss: 3.7660343647003174
[METRICS (1 step stale)] step: 599 | time: 2025-08-20 13:53:55.930789 | loss: 3.7520663738250732
[METRICS (1 step stale)] step: 609 | time: 2025-08-20 13:53:59.000146 | loss: 3.697028636932373
[MESSAGE] Running validation for step 625...
[METRICS (1 step stale)] step: 619 | time: 2025-08-20 13:54:02.069577 | loss: 3.7264294624328613
[MESSAGE] Validation finished for step 625.
[METRICS (1 step stale)] step: 625 | val_loss: 3.649327039718628
[METRICS (1 step stale)] step: 629 | time: 2025-08-20 13:54:10.659129 | loss: 3.660844564437866
[METRICS (1 step stale)] step: 639 | time: 2025-08-20 13:54:11.875259 | loss: 3.7046656608581543
[METRICS (1 step stale)] step: 649 | time: 2025-08-20 13:54:13.102835 | loss: 3.7377099990844727
[METRICS (1 step stale)] step: 659 | time: 2025-08-20 13:54:16.171866 | loss: 3.7365329265594482
[METRICS (1 step stale)] step: 669 | time: 2025-08-20 13:54:19.240802 | loss: 3.6864848136901855
[METRICS (1 step stale)] step: 679 | time: 2025-08-20 13:54:22.309079 | loss: 3.6769893169403076
[METRICS (1 step stale)] step: 689 | time: 2025-08-20 13:54:25.378479 | loss: 3.608036518096924
[METRICS (1 step stale)] step: 699 | time: 2025-08-20 13:54:28.447790 | loss: 3.6322433948516846
[METRICS (1 step stale)] step: 709 | time: 2025-08-20 13:54:31.516965 | loss: 3.6722989082336426
[METRICS (1 step stale)] step: 719 | time: 2025-08-20 13:54:34.584991 | loss: 3.667214870452881
[METRICS (1 step stale)] step: 729 | time: 2025-08-20 13:54:37.654068 | loss: 3.652181386947632
[METRICS (1 step stale)] step: 739 | time: 2025-08-20 13:54:40.722094 | loss: 3.6320765018463135
[MESSAGE] Running validation for step 750...
[METRICS (1 step stale)] step: 749 | time: 2025-08-20 13:54:43.792237 | loss: 3.6550650596618652
[MESSAGE] Validation finished for step 750.
[METRICS (1 step stale)] step: 750 | val_loss: 3.5913608074188232
[METRICS (1 step stale)] step: 759 | time: 2025-08-20 13:54:51.034807 | loss: 3.6236650943756104
[METRICS (1 step stale)] step: 769 | time: 2025-08-20 13:54:52.064710 | loss: 3.660318613052368
[METRICS (1 step stale)] step: 779 | time: 2025-08-20 13:54:54.827199 | loss: 3.630045175552368
[METRICS (1 step stale)] step: 789 | time: 2025-08-20 13:54:57.896864 | loss: 3.5981388092041016
[METRICS (1 step stale)] step: 799 | time: 2025-08-20 13:55:00.965673 | loss: 3.6450397968292236
[METRICS (1 step stale)] step: 809 | time: 2025-08-20 13:55:04.034738 | loss: 3.604555130004883
[METRICS (1 step stale)] step: 819 | time: 2025-08-20 13:55:07.103554 | loss: 3.6314570903778076
[METRICS (1 step stale)] step: 829 | time: 2025-08-20 13:55:10.172477 | loss: 3.5898005962371826
[METRICS (1 step stale)] step: 839 | time: 2025-08-20 13:55:13.974459 | loss: 3.5569725036621094
[METRICS (1 step stale)] step: 849 | time: 2025-08-20 13:55:16.310128 | loss: 3.560964584350586
[METRICS (1 step stale)] step: 859 | time: 2025-08-20 13:55:19.612542 | loss: 3.501368522644043
[MESSAGE] Running validation for step 875...
[METRICS (1 step stale)] step: 869 | time: 2025-08-20 13:55:23.844602 | loss: 3.5661191940307617
[MESSAGE] Validation finished for step 875.
[METRICS (1 step stale)] step: 875 | val_loss: 3.5382885932922363
[METRICS (1 step stale)] step: 879 | time: 2025-08-20 13:55:35.469705 | loss: 3.5392282009124756
[METRICS (1 step stale)] step: 889 | time: 2025-08-20 13:55:36.684914 | loss: 3.6375319957733154
[METRICS (1 step stale)] step: 899 | time: 2025-08-20 13:55:38.377884 | loss: 3.5591230392456055
[METRICS (1 step stale)] step: 909 | time: 2025-08-20 13:55:42.611677 | loss: 3.511077880859375
[METRICS (1 step stale)] step: 919 | time: 2025-08-20 13:55:46.843374 | loss: 3.560610771179199
[METRICS (1 step stale)] step: 929 | time: 2025-08-20 13:55:51.076039 | loss: 3.51541805267334
[METRICS (1 step stale)] step: 939 | time: 2025-08-20 13:55:55.306844 | loss: 3.4899306297302246
[METRICS (1 step stale)] step: 949 | time: 2025-08-20 13:55:59.540829 | loss: 3.5559613704681396
[METRICS (1 step stale)] step: 959 | time: 2025-08-20 13:56:03.773436 | loss: 3.5266826152801514
[METRICS (1 step stale)] step: 969 | time: 2025-08-20 13:56:08.007206 | loss: 3.518746852874756
[METRICS (1 step stale)] step: 979 | time: 2025-08-20 13:56:12.242955 | loss: 3.5214052200317383
[METRICS (1 step stale)] step: 989 | time: 2025-08-20 13:56:16.474534 | loss: 3.5675082206726074
[MESSAGE] Running validation for step 1000...
[METRICS (1 step stale)] step: 999 | time: 2025-08-20 13:56:20.709770 | loss: 3.542637348175049
[MESSAGE] Validation finished for step 1000.
[METRICS (1 step stale)] step: 1000 | val_loss: 3.497018814086914
[METRICS (1 step stale)] step: 1009 | time: 2025-08-20 13:56:30.394738 | loss: 3.560560464859009
[METRICS (1 step stale)] step: 1019 | time: 2025-08-20 13:56:31.424571 | loss: 3.4775853157043457
[METRICS (1 step stale)] step: 1029 | time: 2025-08-20 13:56:35.234030 | loss: 3.497446060180664
[METRICS (1 step stale)] step: 1039 | time: 2025-08-20 13:56:39.466425 | loss: 3.7506730556488037
[METRICS (1 step stale)] step: 1049 | time: 2025-08-20 13:56:43.699382 | loss: 3.553964138031006
[METRICS (1 step stale)] step: 1059 | time: 2025-08-20 13:56:47.932835 | loss: 3.463603973388672
[METRICS (1 step stale)] step: 1069 | time: 2025-08-20 13:56:52.164625 | loss: 3.4607086181640625
[METRICS (1 step stale)] step: 1079 | time: 2025-08-20 13:56:56.404581 | loss: 3.5184004306793213
[METRICS (1 step stale)] step: 1089 | time: 2025-08-20 13:57:00.638680 | loss: 3.475104570388794
[METRICS (1 step stale)] step: 1099 | time: 2025-08-20 13:57:04.872200 | loss: 3.4671080112457275
[METRICS (1 step stale)] step: 1109 | time: 2025-08-20 13:57:09.101193 | loss: 3.5727367401123047
[MESSAGE] Running validation for step 1125...
[METRICS (1 step stale)] step: 1119 | time: 2025-08-20 13:57:13.337652 | loss: 3.5693674087524414
[MESSAGE] Validation finished for step 1125.
[METRICS (1 step stale)] step: 1125 | val_loss: 3.448789358139038
[METRICS (1 step stale)] step: 1129 | time: 2025-08-20 13:57:24.955096 | loss: 3.4653639793395996
[METRICS (1 step stale)] step: 1139 | time: 2025-08-20 13:57:26.171979 | loss: 3.4399702548980713
[METRICS (1 step stale)] step: 1149 | time: 2025-08-20 13:57:27.864129 | loss: 3.558744430541992
[METRICS (1 step stale)] step: 1159 | time: 2025-08-20 13:57:32.100953 | loss: 3.5004079341888428
[METRICS (1 step stale)] step: 1169 | time: 2025-08-20 13:57:36.332562 | loss: 3.474762439727783
[METRICS (1 step stale)] step: 1179 | time: 2025-08-20 13:57:40.567678 | loss: 3.4206159114837646
[METRICS (1 step stale)] step: 1189 | time: 2025-08-20 13:57:44.799589 | loss: 3.4221444129943848
[METRICS (1 step stale)] step: 1199 | time: 2025-08-20 13:57:49.032921 | loss: 3.409047842025757
[METRICS (1 step stale)] step: 1209 | time: 2025-08-20 13:57:53.267073 | loss: 3.446755886077881
[METRICS (1 step stale)] step: 1219 | time: 2025-08-20 13:57:57.499798 | loss: 3.4224650859832764
[METRICS (1 step stale)] step: 1229 | time: 2025-08-20 13:58:01.735303 | loss: 3.3683292865753174
[METRICS (1 step stale)] step: 1239 | time: 2025-08-20 13:58:05.969270 | loss: 3.442624807357788
[MESSAGE] Running validation for step 1250...
[METRICS (1 step stale)] step: 1249 | time: 2025-08-20 13:58:10.203225 | loss: 3.367388963699341
[MESSAGE] Validation finished for step 1250.
[METRICS (1 step stale)] step: 1250 | val_loss: 3.397667646408081
[METRICS (1 step stale)] step: 1259 | time: 2025-08-20 13:58:19.896055 | loss: 3.4043309688568115
[METRICS (1 step stale)] step: 1269 | time: 2025-08-20 13:58:20.924574 | loss: 3.3819515705108643
[METRICS (1 step stale)] step: 1279 | time: 2025-08-20 13:58:24.737422 | loss: 3.3692402839660645
[METRICS (1 step stale)] step: 1289 | time: 2025-08-20 13:58:28.971952 | loss: 3.447669267654419
[METRICS (1 step stale)] step: 1299 | time: 2025-08-20 13:58:33.209369 | loss: 3.479482412338257
[METRICS (1 step stale)] step: 1309 | time: 2025-08-20 13:58:37.444948 | loss: 3.359408140182495
[METRICS (1 step stale)] step: 1319 | time: 2025-08-20 13:58:41.681583 | loss: 3.440824270248413
[METRICS (1 step stale)] step: 1329 | time: 2025-08-20 13:58:45.916973 | loss: 3.338726758956909
[METRICS (1 step stale)] step: 1339 | time: 2025-08-20 13:58:50.152896 | loss: 3.4674599170684814
[METRICS (1 step stale)] step: 1349 | time: 2025-08-20 13:58:54.383599 | loss: 3.3167357444763184
[METRICS (1 step stale)] step: 1359 | time: 2025-08-20 13:58:58.616920 | loss: 3.3100430965423584
[MESSAGE] Running validation for step 1375...
[METRICS (1 step stale)] step: 1369 | time: 2025-08-20 13:59:02.851478 | loss: 3.362697124481201
[MESSAGE] Validation finished for step 1375.
[METRICS (1 step stale)] step: 1375 | val_loss: 3.35252046585083
[METRICS (1 step stale)] step: 1379 | time: 2025-08-20 13:59:14.478478 | loss: 3.396247625350952
[METRICS (1 step stale)] step: 1389 | time: 2025-08-20 13:59:15.692602 | loss: 3.3447999954223633
[METRICS (1 step stale)] step: 1399 | time: 2025-08-20 13:59:17.387808 | loss: 3.389324188232422
[METRICS (1 step stale)] step: 1409 | time: 2025-08-20 13:59:21.624988 | loss: 3.3342318534851074
[METRICS (1 step stale)] step: 1419 | time: 2025-08-20 13:59:25.859846 | loss: 3.3981213569641113
[METRICS (1 step stale)] step: 1429 | time: 2025-08-20 13:59:30.091816 | loss: 3.3394150733947754
[METRICS (1 step stale)] step: 1439 | time: 2025-08-20 13:59:34.326004 | loss: 3.3047027587890625
[METRICS (1 step stale)] step: 1449 | time: 2025-08-20 13:59:38.562969 | loss: 3.317003011703491
[METRICS (1 step stale)] step: 1459 | time: 2025-08-20 13:59:42.798248 | loss: 3.375948190689087
[METRICS (1 step stale)] step: 1469 | time: 2025-08-20 13:59:47.031706 | loss: 3.319119453430176
[METRICS (1 step stale)] step: 1479 | time: 2025-08-20 13:59:51.266632 | loss: 3.3129196166992188
[METRICS (1 step stale)] step: 1489 | time: 2025-08-20 13:59:55.496505 | loss: 3.2776339054107666
[MESSAGE] Running validation for step 1500...
[METRICS (1 step stale)] step: 1499 | time: 2025-08-20 13:59:59.731467 | loss: 3.245753288269043
[MESSAGE] Validation finished for step 1500.
[METRICS (1 step stale)] step: 1500 | val_loss: 3.3116257190704346
[METRICS (1 step stale)] step: 1509 | time: 2025-08-20 14:00:09.419534 | loss: 3.2631149291992188
[METRICS (1 step stale)] step: 1519 | time: 2025-08-20 14:00:10.448179 | loss: 3.3242125511169434
[METRICS (1 step stale)] step: 1529 | time: 2025-08-20 14:00:14.257816 | loss: 3.3455348014831543
[METRICS (1 step stale)] step: 1539 | time: 2025-08-20 14:00:18.490213 | loss: 3.312182664871216
[METRICS (1 step stale)] step: 1549 | time: 2025-08-20 14:00:22.721110 | loss: 3.277139663696289
[METRICS (1 step stale)] step: 1559 | time: 2025-08-20 14:00:26.955341 | loss: 3.30891752243042
[METRICS (1 step stale)] step: 1569 | time: 2025-08-20 14:00:31.193679 | loss: 3.2312824726104736
[METRICS (1 step stale)] step: 1579 | time: 2025-08-20 14:00:35.422963 | loss: 3.299146890640259
[METRICS (1 step stale)] step: 1589 | time: 2025-08-20 14:00:39.652532 | loss: 3.3006279468536377
[METRICS (1 step stale)] step: 1599 | time: 2025-08-20 14:00:43.884421 | loss: 3.3266823291778564
[METRICS (1 step stale)] step: 1609 | time: 2025-08-20 14:00:48.115959 | loss: 3.2253568172454834
[MESSAGE] Running validation for step 1625...
[METRICS (1 step stale)] step: 1619 | time: 2025-08-20 14:00:52.352488 | loss: 3.4040253162384033
[MESSAGE] Validation finished for step 1625.
[METRICS (1 step stale)] step: 1625 | val_loss: 3.2820873260498047
[METRICS (1 step stale)] step: 1629 | time: 2025-08-20 14:01:03.965982 | loss: 3.2963802814483643
[METRICS (1 step stale)] step: 1639 | time: 2025-08-20 14:01:05.180357 | loss: 3.3082056045532227
[METRICS (1 step stale)] step: 1649 | time: 2025-08-20 14:01:06.876225 | loss: 3.264956474304199
[METRICS (1 step stale)] step: 1659 | time: 2025-08-20 14:01:11.107810 | loss: 3.2650082111358643
[METRICS (latest)] step: 1669 | time: 2025-08-20 14:01:15.339523 | loss: 3.258247137069702
[MESSAGE] Final validation
[MESSAGE] Running validation for step 1674...
[MESSAGE] Validation finished for step 1674.
[METRICS (latest)] step: 1674 | val_loss: 3.275080919265747
[MESSAGE] Training finished.
[MESSAGE] Saved checkpoint to logs/7f8f89f1-d8f9-462f-98f6-1027ed9b7159//state_step001674.pkl