newsaiworld
Debugging the Dreaded NaN | Towards Data Science

Admin by Admin
March 2, 2025
You are training your latest AI model, anxiously watching as the loss steadily decreases, when suddenly: boom! Your logs are flooded with NaNs (Not a Number). Your model is irreparably corrupted and you are left staring at your screen in despair. To make matters worse, the NaNs don't appear consistently. Sometimes your model trains just fine; other times, it fails inexplicably. Sometimes it will crash immediately, sometimes after many days of training.

NaNs in deep learning workloads are among the most frustrating issues to encounter. And because they often appear sporadically, triggered by a specific combination of model state, input data, and stochastic factors, they can be incredibly difficult to reproduce and debug.

Given the considerable cost of training AI models and the potential waste caused by NaN failures, it is recommended to have dedicated tools for capturing and analyzing NaN occurrences. In a previous post, we discussed the challenge of debugging NaNs in a TensorFlow training workload. We proposed an efficient scheme for capturing and reproducing NaNs and shared a sample TensorFlow implementation. In this post, we adopt and demonstrate a similar mechanism for debugging NaNs in PyTorch workloads. The general scheme is as follows:

On each training step:

  1. Save a copy of the training input batch.
  2. Check the gradients for NaN values. If any appear, save a checkpoint with the current model weights before the model is corrupted. Also, save the input batch and, if necessary, the stochastic state. Discontinue the training job.
  3. Reproduce and debug the NaN occurrence by loading the saved experiment state.
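The steps above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the implementation used in this post: the function name `train_with_nan_capture`, the checkpoint file name `nan_capture.pt`, and the toy linear model are all hypothetical, and the NaN is provoked artificially by feeding NaN inputs.

```python
import os
import tempfile
from copy import deepcopy

import torch
import torch.nn.functional as F


def train_with_nan_capture(model, optimizer, loss_fn, batches, ckpt_dir):
    """Train until done or until NaN gradients appear.

    Returns the checkpoint path if a NaN was captured, otherwise None.
    """
    for step, batch in enumerate(batches):
        # 1. keep a copy of the input batch
        last_batch = deepcopy(batch)

        inputs, targets = batch
        optimizer.zero_grad()
        loss_fn(model(inputs), targets).backward()

        # 2. inspect the gradients before the optimizer applies them
        grads = [p.grad.reshape(-1) for p in model.parameters()
                 if p.grad is not None]
        if grads and torch.isnan(torch.cat(grads)).any().item():
            os.makedirs(ckpt_dir, exist_ok=True)
            path = os.path.join(ckpt_dir, "nan_capture.pt")  # hypothetical name
            torch.save({"model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "last_batch": last_batch,
                        "step": step}, path)
            return path  # discontinue training; weights are still uncorrupted

        optimizer.step()
    return None


# toy usage: the second batch holds NaN inputs, which produce NaN gradients
torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
good = (torch.ones(4, 2), torch.zeros(4, 1))
bad = (torch.full((4, 2), float("nan")), torch.zeros(4, 1))

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt = train_with_nan_capture(model, optimizer, F.mse_loss,
                                  [good, bad, good], tmp_dir)
    # 3. reload the saved state for debugging
    state = torch.load(ckpt, weights_only=False)
```

Because the check runs before `optimizer.step()`, the saved weights are the last finite ones, which is what makes the failing step reproducible.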

Although this scheme can be easily implemented in native PyTorch, we will take the opportunity to demonstrate some of the conveniences of PyTorch Lightning, a powerful open-source framework designed to streamline the development of machine learning (ML) models. Built on PyTorch, Lightning abstracts away many of the boilerplate components of an ML experiment, such as training loops, data distribution, logging, and more, enabling developers to focus on the core logic of their models.

To implement our NaN capturing scheme, we will use Lightning's callback interface, a dedicated structure that enables inserting custom logic at specific points during the flow of execution.
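To make the idea of a callback interface concrete, here is a toy illustration in pure Python. It is not Lightning's implementation: `ToyTrainer`, `Callback`, and `StopAtBatch` are hypothetical names, and the hook signatures are simplified. It only shows how a training loop can hand control to user code at fixed points.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_train_batch_start(self, trainer, batch, batch_idx):
        pass

    def on_before_optimizer_step(self, trainer):
        pass


class ToyTrainer:
    """Hypothetical stand-in for a trainer that invokes callback hooks."""

    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.should_stop = False

    def fit(self, batches):
        steps_run = 0
        for batch_idx, batch in enumerate(batches):
            for cb in self.callbacks:
                cb.on_train_batch_start(self, batch, batch_idx)
            # the forward and backward passes would run here
            for cb in self.callbacks:
                cb.on_before_optimizer_step(self)
            # the optimizer step would run here
            steps_run += 1
            # the loop exits once the current step has finished
            if self.should_stop:
                break
        return steps_run


class StopAtBatch(Callback):
    """Requests a stop when a chosen batch index is reached."""

    def __init__(self, stop_idx):
        self.stop_idx = stop_idx
        self.current_idx = None

    def on_train_batch_start(self, trainer, batch, batch_idx):
        self.current_idx = batch_idx

    def on_before_optimizer_step(self, trainer):
        if self.current_idx == self.stop_idx:
            trainer.should_stop = True


steps = ToyTrainer([StopAtBatch(stop_idx=2)]).fit(range(10))
```

The callback never touches the loop itself; it only reads what the hooks pass in and sets a flag, which is the same division of labor the NaNCapture callback relies on.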

Importantly, please do not view our choice of Lightning or any other tool or technique that we mention as an endorsement of its use. The code that we will share is intended for demonstrative purposes; please do not rely on its correctness or optimality.

Many thanks to Rom Maltser for his contributions to this post.

NaNCapture Callback

To implement our NaN capturing solution, we create a NaNCapture Lightning callback. The constructor receives a directory path for storing/loading checkpoints and sets up the NaNCapture state. We also define utilities for checking for NaNs, storing checkpoints, and halting the training job.

import os
import torch
from copy import deepcopy
import lightning.pytorch as pl

class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath

        # update to True when NaN is identified
        self.nan_captured = False

        # stores a copy of the last batch
        self.last_batch = None
        self.batch_idx = None

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively check for finite
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        # communicate stop command to all other ranks
        trainer.strategy.reduce_boolean_decision(trainer.should_stop,
                                                 all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, False)

Callback Function: on_train_batch_start

We begin by implementing the on_train_batch_start hook to store a copy of each input batch. In case of a NaN event, this batch will be stored in the checkpoint.
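A minimal sketch of this hook, consistent with the fuller listing later in this post but without the random-state handling. It is written as a plain mixin (`BatchCaptureMixin` is a hypothetical name) so that it runs without Lightning installed; in the callback itself, the method would simply be defined on NaNCapture.

```python
from copy import deepcopy


class BatchCaptureMixin:
    """Hypothetical mixin that holds only the batch-capturing hook."""

    last_batch = None
    batch_idx = None

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # deepcopy so that later in-place changes to the batch
        # do not alter the stored copy
        self.last_batch = deepcopy(batch)
        self.batch_idx = batch_idx


# toy usage with nested lists standing in for tensors
hook = BatchCaptureMixin()
batch = [[1.0, 2.0], [0]]
hook.on_train_batch_start(None, None, batch, 7)
batch[0][0] = 99.0  # mutate the original after it was captured
```

Copying every batch costs some memory and time on each step; the performance section of this post measures that cost for the toy model.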

Callback Function: on_before_optimizer_step

Next we implement the on_before_optimizer_step hook. Here, we check for NaN entries in all of the gradient tensors. If found, we store a checkpoint with the uncorrupted model weights and halt the training.

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            # Check if gradients contain NaN
            grads = [p.grad.view(-1) for p in pl_module.parameters()
                     if p.grad is not None]
            all_grads = torch.cat(grads)
            if self.contains_nan(all_grads):
                print("nan found")
                # mark the capture so that state_dict() stores the batch
                self.nan_captured = True
                self.save_ckpt(trainer)
                self.halt_training(trainer)
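The commented-out alternative in contains_nan hints at a design choice: `torch.isnan` flags only NaN entries, while `torch.isfinite` also flags positive and negative infinity, which frequently precede a NaN. The sketch below contrasts the two checks on a small tensor; the helper name `contains_non_finite` is ours and not part of the callback.

```python
import torch


def contains_nan(tensor):
    # True only if at least one entry is NaN
    return torch.isnan(tensor).any().item()


def contains_non_finite(tensor):
    # True if at least one entry is NaN, +inf or -inf
    return not torch.isfinite(tensor).all().item()


grads_with_inf = torch.tensor([0.5, float("inf"), -1.0])
grads_with_nan = torch.tensor([0.5, float("nan"), -1.0])

inf_seen_by_nan_check = contains_nan(grads_with_inf)
inf_seen_by_finite_check = contains_non_finite(grads_with_inf)
```

Whether to stop on infinities as well is a judgment call: it catches some failures a step earlier, but it may also halt runs in which an occasional infinite gradient is handled elsewhere, for example by a gradient scaler in mixed-precision training.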

Capturing the Training State

To enable reproducibility, we include the NaNCapture state in the checkpoint by appending it to the training state dictionary. Lightning provides dedicated utilities for saving and loading a callback state:

    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]

Reproducing the NaN Occurrence

We have described how our NaNCapture callback can be used to store the training state that resulted in a NaN, but how do we reload this state in order to reproduce the issue and debug it? To accomplish this, we leverage Lightning's dedicated data loading class, LightningDataModule.

DataModule Function: on_before_batch_transfer

In the code block below, we extend the LightningDataModule class to allow injecting a fixed training input batch. This is achieved by overriding the on_before_batch_transfer hook:

from lightning.pytorch import LightningDataModule

class InjectableDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()
        self.cached_batch = None

    def set_custom_batch(self, batch):
        self.cached_batch = batch

    def on_before_batch_transfer(self, batch, dataloader_idx):
        # compare with None: the truth value of a multi-element
        # tensor batch is ambiguous and would raise an error
        if self.cached_batch is not None:
            return self.cached_batch
        return batch

Callback Function: on_train_start

The final step is modifying the on_train_start hook of our NaNCapture callback to inject the stored training batch into the LightningDataModule.

    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)

In the next section we will demonstrate the end-to-end solution using a toy example.

Toy Example

To test our new callback, we create a resnet50-based image classification model with a loss function deliberately designed to trigger NaN occurrences.

Instead of using the standard CrossEntropy loss, we compute binary_cross_entropy_with_logits for each class independently and divide the result by the number of samples belonging to that class. Inevitably, we will encounter a batch in which one or more classes are missing, leading to a divide-by-zero operation, resulting in NaN values and corrupting the model.
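The path from a divide-by-zero to a NaN can be seen on a much smaller scale. In the sketch below (a toy with made-up tensors, not the resnet50 model), dividing a positive summed loss by a zero count gives an infinite loss, and the NaN appears in the backward pass where the infinite upstream gradient is multiplied by an input that is exactly zero.

```python
import torch

# one input feature is exactly zero (common after a ReLU)
x = torch.tensor([0.0, 3.0])
w = torch.tensor([1.0, 2.0], requires_grad=True)

# stand-in for the summed per-class loss; finite and positive
summed_loss = (w * x).sum()

# labels of the batch: class 0 is absent, so its count is zero
labels = torch.tensor([1, 2])
count = torch.count_nonzero(labels == 0)

mean_loss = summed_loss / count  # positive value divided by zero
mean_loss.backward()

loss_is_inf = torch.isinf(mean_loss).item()
grad_has_nan = torch.isnan(w.grad).any().item()
```

Here the loss itself is infinite rather than NaN; it is the gradient that contains the NaN, which is one reason the callback inspects gradients rather than the loss value.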

The implementation below follows Lightning's introductory tutorial.

import lightning.pytorch as pl
import torch
import torchvision
import torch.nn.functional as F

num_classes = 20


# define a lightning module
class ResnetModel(pl.LightningModule):
    def __init__(self):
        """Initializes a new instance of the ResnetModel class."""
        super().__init__()
        self.model = torchvision.models.resnet50(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        outputs = self(x)
        # uncomment for default loss
        # return F.cross_entropy(outputs, y)

        # calculate binary_cross_entropy for each class individually
        losses = []
        for c in range(num_classes):
            count = torch.count_nonzero(y==c)
            masked = torch.where(y==c, 1., 0.)
            loss = F.binary_cross_entropy_with_logits(
                outputs[..., c],
                masked,
                reduction='sum'
            )
            mean_loss = loss/count # could result in NaN
            losses.append(mean_loss)
        total_loss = torch.stack(losses).mean()
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

We define a synthetic dataset and encapsulate it in our InjectableDataModule class:

import os
import random
from torch.utils.data import Dataset, DataLoader

batch_size = 128
num_steps = 800

# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return batch_size*num_steps

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(random.randint(0, num_classes-1),
                             dtype=torch.int64)
        return rand_image, label


# define a lightning datamodule
class FakeDataModule(InjectableDataModule):

    def train_dataloader(self):
        dataset = FakeDataset()
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=os.cpu_count(),
            pin_memory=True
        )

Finally, we initialize a Lightning Trainer with our NaNCapture callback and call trainer.fit with our Lightning module and Lightning DataModule.

import time

if __name__ == "__main__":

    # Initialize a lightning module
    lit_module = ResnetModel()

    # Initialize a DataModule
    mnist_data = FakeDataModule()

    # Train the model
    ckpt_dir = "./ckpt_dir"
    trainer = pl.Trainer(
        max_epochs=1,
        callbacks=[NaNCapture(ckpt_dir)]
    )

    ckpt_path = None

    # check if nan ckpt exists
    if os.path.isdir(ckpt_dir):
        dir_contents = [os.path.join(ckpt_dir, f)
                        for f in os.listdir(ckpt_dir)]
        ckpts = [f for f in dir_contents
                 if os.path.isfile(f) and f.endswith('.ckpt')]
        if ckpts:
            ckpt_path = ckpts[0]

    t0 = time.perf_counter()
    trainer.fit(lit_module, mnist_data, ckpt_path=ckpt_path)
    print(f"total runtime: {time.perf_counter() - t0}")

After a number of training steps, a NaN event will occur. At this point a checkpoint is saved with the full training state and the training is halted.

When the script is run again, the exact state that caused the NaN will be reloaded, allowing us to easily reproduce the issue and debug its root cause.
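Once the failing step can be replayed, one possible way to narrow down the root cause (a suggestion of ours, not part of the scheme described in this post) is PyTorch's autograd anomaly detection, which raises an error naming the backward function that produced a NaN. The sketch below uses a hypothetical helper, `find_nan_source`, and a toy loss in place of the real model.

```python
import torch


def find_nan_source(make_loss):
    """Run forward and backward under anomaly detection.

    Returns the error message if a backward function produced NaN,
    otherwise None.
    """
    try:
        with torch.autograd.detect_anomaly():
            make_loss().backward()
    except RuntimeError as err:
        return str(err)
    return None


def toy_loss():
    # a zero input combined with a division by zero, as in the toy example
    w = torch.tensor([1.0, 2.0], requires_grad=True)
    x = torch.tensor([0.0, 3.0])
    return (w * x).sum() / torch.tensor(0.0)


def clean_loss():
    w = torch.tensor([1.0, 2.0], requires_grad=True)
    return (w * 2.0).sum()


message = find_nan_source(toy_loss)
clean_message = find_nan_source(clean_loss)
```

Anomaly detection slows training considerably, so it is best enabled only while replaying the captured step. Lightning's Trainer also accepts a `detect_anomaly` flag that enables the same mechanism.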

Performance Overhead

To assess the impact of our NaNCapture callback on runtime performance, we modified our experiment to use CrossEntropyLoss (to avoid NaNs) and measured the average throughput when running with and without the NaNCapture callback. The experiments were conducted on an NVIDIA L40S GPU, with a PyTorch 2.5.1 Docker image.

Overhead of NaNCapture Callback (by Author)

For our toy model, the NaNCapture callback adds a minimal 1.5% overhead to the runtime performance, a small price to pay for the valuable debugging capabilities it provides.

Naturally, the actual overhead will depend on the specifics of the model and runtime environment.

How to Handle Stochasticity

The solution we have described thus far will succeed in reproducing the training state provided that the model does not include any randomness. However, introducing stochasticity into the model definition is often critical for convergence. A common example of a stochastic layer is torch.nn.Dropout.

You may find that your NaN event depends on the precise state of randomness when the failure occurred. Consequently, we would like to enhance our NaNCapture callback to capture and restore the random state at the point of failure. The random state is determined by a number of libraries. In the code block below, we attempt to capture the full state of randomness:

import os
import torch
import random
import numpy as np
from copy import deepcopy
import lightning.pytorch as pl

class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath

        # update to True when NaN is identified
        self.nan_captured = False

        # stores a copy of the last batch
        self.last_batch = None
        self.batch_idx = None

        # rng state
        self.rng_state = {
            "torch": None,
            "torch_cuda": None,
            "numpy": None,
            "random": None
        }

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively check for finite
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        # communicate stop command to all other ranks
        trainer.strategy.reduce_boolean_decision(trainer.should_stop,
                                                 all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, False)

    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            # inject batch
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.nan_captured:
            # restore random state
            torch.random.set_rng_state(self.rng_state["torch"])
            torch.cuda.set_rng_state_all(self.rng_state["torch_cuda"])
            np.random.set_state(self.rng_state["numpy"])
            random.setstate(self.rng_state["random"])
        else:
            # capture current batch
            self.last_batch = deepcopy(batch)
            self.batch_idx = batch_idx

            # capture current random state
            self.rng_state["torch"] = torch.random.get_rng_state()
            self.rng_state["torch_cuda"] = torch.cuda.get_rng_state_all()
            self.rng_state["numpy"] = np.random.get_state()
            self.rng_state["random"] = random.getstate()

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            # Check if gradients contain NaN
            grads = [p.grad.view(-1) for p in pl_module.parameters()
                     if p.grad is not None]
            all_grads = torch.cat(grads)
            if self.contains_nan(all_grads):
                print("nan found")
                # mark the capture so that state_dict() stores the state
                self.nan_captured = True
                self.save_ckpt(trainer)
                self.halt_training(trainer)

    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
            d["rng_state"] = self.rng_state
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]
            self.rng_state = state_dict["rng_state"]
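The capture-and-restore pattern used in on_train_batch_start can be checked in isolation. The sketch below does so for Python's built-in `random` module only, so that it runs without any third-party packages; the helper names `capture_rng_state` and `restore_rng_state` are hypothetical, and the callback above applies the same idea to the torch, CUDA, and NumPy generators.

```python
import random


def capture_rng_state():
    # the callback additionally captures torch, CUDA and numpy states
    return {"random": random.getstate()}


def restore_rng_state(state):
    random.setstate(state["random"])


random.seed(1234)
state = capture_rng_state()

# draws made after the capture, e.g. by a stochastic layer
first_draws = [random.random() for _ in range(3)]

# restoring the state replays exactly the same draws
restore_rng_state(state)
second_draws = [random.random() for _ in range(3)]
```

The state has to be captured before the step consumes any random numbers, which is why the callback does this at the start of each batch rather than at the point where the NaN is detected.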

Importantly, setting the random state may not guarantee full reproducibility. The GPU owes its power to its massive parallelism. In some GPU operations, multiple threads may read or write concurrently to the same memory locations, resulting in nondeterminism. PyTorch allows for some control over this via its use_deterministic_algorithms setting, but this may impact the runtime performance. Additionally, there is a possibility that the NaN event will not be reproduced once this configuration setting is changed. Please see the PyTorch documentation on reproducibility for more details.
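As a rough sketch of what enabling this setting might look like, the hypothetical helper `enable_determinism` below turns on deterministic algorithms and disables cuDNN autotuning. The `CUBLAS_WORKSPACE_CONFIG` value follows the PyTorch reproducibility notes; whether these settings are sufficient, or affordable, for a given workload is an assumption to verify.

```python
import os

import torch


def enable_determinism(warn_only=True):
    """Hypothetical helper: request deterministic PyTorch behavior."""
    # needed by cuBLAS for deterministic results on CUDA 10.2 and later;
    # it must be set before the first cuBLAS call
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # with warn_only=True, operations that lack a deterministic
    # implementation emit a warning instead of raising an error
    torch.use_deterministic_algorithms(True, warn_only=warn_only)

    # cuDNN autotuning may select different kernels from run to run
    torch.backends.cudnn.benchmark = False


enable_determinism()
enabled = torch.are_deterministic_algorithms_enabled()

# restore the default so the rest of the process is unaffected
torch.use_deterministic_algorithms(False)
```

Starting with `warn_only=True` is a cautious choice while replaying a captured NaN: it reveals which operations are nondeterministic without aborting the run.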

Summary

Encountering NaN failures is one of the most discouraging events that can happen in machine learning development. These errors not only waste valuable computation and development resources, but often indicate fundamental issues in the model architecture or experiment design. Due to their sporadic, sometimes elusive nature, debugging NaN failures can be a nightmare.

This post introduced a proactive approach for capturing and reproducing NaN errors using a dedicated Lightning callback. The solution we shared is a proposal which can be modified and extended for your specific use case.

While this solution may not address every possible NaN scenario, it significantly reduces debugging time when applicable, potentially saving developers countless hours of frustration and wasted effort.
