# Source code for checkpoint
"""Uses `pickle` to save and restore populations (and other aspects of the simulation state)."""
import gzip
import pickle
import random
import time
from neat.population import Population
from neat.reporting import BaseReporter
# [docs]
class Checkpointer(BaseReporter):
    """
    A reporter class that performs checkpointing using `pickle`
    to save and restore populations (and other aspects of the simulation state).

    A checkpoint is a gzip-compressed pickle of the tuple
    ``(generation, config, population, species_set, random_state)``.

    .. warning::
        Checkpoints are unpickled with :func:`pickle.load`, which can execute
        arbitrary code. Only restore checkpoint files from a trusted source.
    """

    def __init__(self, generation_interval, time_interval_seconds=None,
                 filename_prefix='neat-checkpoint-'):
        """
        Saves the current state (at the end of a generation) every ``generation_interval`` generations or
        ``time_interval_seconds``, whichever happens first.

        :param generation_interval: If not None, maximum number of generations between save intervals
        :type generation_interval: int or None
        :param time_interval_seconds: If not None, maximum number of seconds between checkpoint attempts
        :type time_interval_seconds: float or None
        :param str filename_prefix: Prefix for the filename (the end will be the generation number)
        """
        self.generation_interval = generation_interval
        self.time_interval_seconds = time_interval_seconds
        self.filename_prefix = filename_prefix

        # Set by start_generation(); None until the first generation begins.
        self.current_generation = None
        # Starts at -1 so that generation ``g`` counts as ``g + 1`` generations
        # since the "last" checkpoint before any checkpoint has been written.
        self.last_generation_checkpoint = -1
        # The time-based interval is measured from construction of this reporter.
        self.last_time_checkpoint = time.time()

    def start_generation(self, generation):
        """
        Record the number of the generation that is starting.

        :param generation: The generation number; used later in the checkpoint filename.
        """
        self.current_generation = generation

    def end_generation(self, config, population, species_set):
        """
        Write a checkpoint if the time interval or the generation interval has elapsed.

        The time interval is tested first; the generation interval is only
        consulted when the time interval did not already make a checkpoint due.
        After a save, both the generation marker and the timer are reset.

        :param config: Configuration object to store in the checkpoint.
        :param population: Population to store in the checkpoint.
        :param species_set: Species set to store in the checkpoint.
        """
        checkpoint_due = False

        if self.time_interval_seconds is not None:
            dt = time.time() - self.last_time_checkpoint
            if dt >= self.time_interval_seconds:
                checkpoint_due = True

        if not checkpoint_due and self.generation_interval is not None:
            dg = self.current_generation - self.last_generation_checkpoint
            if dg >= self.generation_interval:
                checkpoint_due = True

        if checkpoint_due:
            self.save_checkpoint(config, population, species_set, self.current_generation)
            self.last_generation_checkpoint = self.current_generation
            self.last_time_checkpoint = time.time()

    def save_checkpoint(self, config, population, species_set, generation):
        """
        Save the current simulation state.

        The state is written to ``<filename_prefix><generation>`` as a
        gzip-compressed pickle, together with the state of the global
        `random` module generator.

        Note: This is called from Population via the reporter interface, so there
        is no direct access to the Population (or its reproduction object) here.
        The innovation tracker is stored in ``config.genome_config.innovation_tracker``
        and is therefore included automatically when ``config`` is pickled.

        :param config: Configuration object to pickle.
        :param population: Population to pickle.
        :param species_set: Species set to pickle.
        :param generation: Generation number; stored in the checkpoint and appended to the filename.
        """
        filename = '{0}{1}'.format(self.filename_prefix, generation)
        print("Saving checkpoint to {0}".format(filename))

        with gzip.open(filename, 'wb', compresslevel=5) as f:
            data = (generation, config, population, species_set, random.getstate())
            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def restore_checkpoint(filename, new_config=None):
        """
        Resumes the simulation from a previous saved point.

        The innovation tracker state is preserved in the pickled config and must be
        transferred to the new reproduction object to ensure innovation numbers continue
        correctly and prevent collisions during crossover.

        Restoring also resets the global `random` module generator to the
        state stored in the checkpoint.

        .. warning::
            This unpickles ``filename``. Unpickling can execute arbitrary code,
            so only restore checkpoints that come from a trusted source.

        :param filename: Path of the checkpoint file to load.
        :param new_config: If not None, configuration to use instead of the one stored in the checkpoint.
        :return: A Population built from the restored population, species set and generation.
        """
        # SECURITY: pickle.load() will run code embedded in a malicious file.
        # Never call this on checkpoints from an untrusted origin.
        with gzip.open(filename, 'rb') as f:
            generation, saved_config, population, species_set, rndstate = pickle.load(f)

        random.setstate(rndstate)

        # Grab the saved innovation tracker from the pickled config *before*
        # building the Population, which creates a fresh tracker of its own.
        saved_innovation_tracker = getattr(saved_config.genome_config, 'innovation_tracker', None)

        # Use new config if provided, otherwise use saved config.
        config = new_config if new_config is not None else saved_config

        # Create Population with restored state.
        # This creates a new reproduction object with a fresh innovation tracker.
        restored_pop = Population(config, (population, species_set, generation))

        # Replace the fresh innovation tracker with the saved one to maintain
        # the correct innovation numbering sequence.
        if saved_innovation_tracker is not None:
            restored_pop.reproduction.innovation_tracker = saved_innovation_tracker
            config.genome_config.innovation_tracker = saved_innovation_tracker

        return restored_pop