Source code for lettuce.util.datautils

"""
datautils for writing/reading hdf5 files.
"""

import h5py
from torch.utils import data
from lettuce._version import get_versions
import pickle
import io
import numpy as np
from lettuce._simulation import Reporter

# Public API of this module: the HDF5 writer and the matching torch dataset.
__all__ = [
    "HDF5Reporter",
    "LettuceDataset",
]


class HDF5Reporter(Reporter):
    """
    HDF5 reporter for the distribution function f in lettuce, stored together
    with metadata of the simulation.

    On construction the file ``filebase + '.h5'`` is created (an existing file
    is overwritten, the file is opened in ``'w'`` mode). The lettuce version,
    the pickled flow and the pickled collision are stored as file attributes,
    and an empty, resizable dataset ``"f"`` is created. Each time the reporter
    is called on a step that is a multiple of ``interval``, one snapshot of
    ``simulation.flow.f`` is appended to that dataset.

    Parameters
    ----------
    flow :
        Flow object. Its ``context`` is used to convert f to a numpy array,
        and ``flow.stencil.q`` and ``flow.grid[0].shape`` define the shape of
        one snapshot. The flow itself is pickled into the file attribute
        ``"flow"``.
    collision :
        Collision object, pickled into the file attribute ``"_collision"``.
    interval : integer
        Step interval of the reporter. f is saved whenever
        ``simulation.flow.i % interval == 0``.
    filebase : string
        Path to the hdf5 file without extension; ``'.h5'`` is appended.
    metadata : dictionary
        Optional metadata can be saved. Each key/value pair is written as a
        file attribute. The passed values must be of type string.
        >>> metadata = {"attr_1": "str_value_1", "attr_2": "str_value_2"}

    Examples
    --------
    Create a HDF5 reporter.
    >>> import lettuce as lt
    >>> context = Context()
    >>> flow = lt.TaylorGreenVortex(context, [50, 50], 300, 0.1)
    >>> collision = ...
    >>> simulation = ...
    >>> hdf5_reporter = lt.HDF5Reporter(
    >>>     flow=flow,
    >>>     collision=collision,
    >>>     interval=100,
    >>>     filebase="./h5_output")
    >>> simulation.reporters.append(hdf5_reporter)
    """

    def __init__(self,
                 flow,
                 collision,
                 interval,
                 filebase='./output',
                 metadata=None):
        self.context = flow.context
        self.interval = interval
        self.filebase = filebase
        # Shape of a single snapshot: (q, *grid shape).
        self.shape = (flow.stencil.q, *flow.grid[0].shape)

        # The context manager guarantees the file is closed even if pickling,
        # attribute writing or dataset creation raises.
        with h5py.File(self.filebase + '.h5', 'w') as fs:
            fs.attrs['lettuce_version'] = get_versions()['version']
            fs.attrs["flow"] = self._pickle_to_h5(flow)
            fs.attrs['_collision'] = self._pickle_to_h5(collision)
            if metadata:
                for attr, value in metadata.items():
                    fs.attrs[attr] = value
            # Start with zero snapshots; the first axis grows without bound.
            fs.create_dataset(name="f",
                              shape=(0, *self.shape),
                              maxshape=(None, *self.shape))

    def __call__(self, simulation: 'Simulation'):
        """Append the current f to the file if the step matches the interval.

        Besides the snapshot, the file attributes ``'data'`` (number of stored
        snapshots) and ``'steps'`` (current ``simulation.flow.i``) are
        updated; both are stored as strings.
        """
        if simulation.flow.i % self.interval == 0:
            with h5py.File(self.filebase + '.h5', 'r+') as fs:
                # Grow the dataset by one entry and fill the new last slot.
                fs["f"].resize(fs["f"].shape[0] + 1, axis=0)
                fs["f"][-1, ...] = self.context.convert_to_ndarray(
                    simulation.flow.f)
                fs.attrs['data'] = str(fs["f"].shape[0])
                fs.attrs['steps'] = str(simulation.flow.i)

    @staticmethod
    def _pickle_to_h5(instance):
        """Pickle ``instance`` and wrap the bytes in ``np.void`` so they can
        be stored as an opaque HDF5 attribute."""
        return np.void(pickle.dumps(instance))
class LettuceDataset(data.Dataset):
    """
    Custom dataset for HDF5 files in lettuce that can be used by torch's
    dataloader.

    The file is opened read-only on construction and stays open until the
    dataset object is deleted. The context used to convert the stored data
    into tensors is taken from the pickled flow in the file attribute
    ``"flow"``.

    NOTE: the ``"flow"`` and ``"_collision"`` attributes are unpickled.
    Unpickling can execute arbitrary code, so only open files from a trusted
    source.

    Parameters
    ----------
    filebase : string
        Path to the hdf5 file. It is passed to ``h5py.File`` unchanged, so it
        must include the file extension.
    transform : class object
        Optional transform to be applied on a f loaded from HDF5 file. If
        ``target`` is True it is applied to the target as well.
    target : logical operation (True, False)
        Returns also the next dataset[idx + skip_idx_to_target]
        - default=False
    skip_idx_to_target : integer
        Define which next target dataset is returned if target is True
        - default=1

    Examples
    --------
    Create a data loader.
    >>> import lettuce as lt
    >>> import torch
    >>> dataset_train = lt.LettuceDataset(filebase="./hdf5_output.h5",
    >>>                                   target=True)
    >>> train_loader = torch.utils.data.DataLoader(dataset_train,
    >>>                                            shuffle=True)
    >>> for (f, target, idx) in train_loader:
    >>>     ...
    """

    def __init__(self,
                 filebase,
                 transform=None,
                 target=False,
                 skip_idx_to_target=1):
        super().__init__()
        self.filebase = filebase
        self.transform = transform
        self.target = target
        self.skip_idx_to_target = skip_idx_to_target
        self.fs = h5py.File(self.filebase, "r")
        self.shape = self.fs["f"].shape
        self.keys = list(self.fs.keys())
        self.context = self._unpickle_from_h5(self.fs.attrs["flow"]).context

    def __str__(self):
        """Return all file attributes, one ``name: value`` pair per line.

        The pickled ``flow`` and ``_collision`` attributes are unpickled
        before being converted to text.
        """
        lines = []
        for attr, value in self.fs.attrs.items():
            if attr in ('flow', '_collision'):
                value = self._unpickle_from_h5(value)
            lines.append(attr + ": " + str(value))
        return "\n".join(lines)

    def __len__(self):
        """Number of usable samples.

        With ``target`` enabled the last ``skip_idx_to_target`` entries have
        no target and are excluded. The result is clamped to 0 because a
        negative ``__len__`` makes ``len()`` raise ``ValueError``.
        """
        if self.target:
            return max(0, self.shape[0] - self.skip_idx_to_target)
        return self.shape[0]

    def __getitem__(self, idx):
        """Return ``(f, target, idx)`` if ``target`` is set, else
        ``(f, idx)``."""
        f = self.get_data(idx)
        target = None
        if self.target:
            target = self.get_data(idx + self.skip_idx_to_target)
        if self.transform:
            f = self.transform(f)
            if self.target:
                target = self.transform(target)
        return (f, target, idx) if self.target else (f, idx)

    def __del__(self):
        # `fs` is missing when __init__ failed before the file was opened.
        fs = getattr(self, "fs", None)
        if fs is not None:
            fs.close()

    def get_data(self, idx):
        """Load entry ``idx`` of the dataset ``"f"`` and convert it to a
        tensor using the stored context."""
        return self.context.convert_to_tensor(self.fs["f"][idx])

    def get_attr(self, attr):
        """Return the raw file attribute ``attr``."""
        return self.fs.attrs[attr]

    @staticmethod
    def _unpickle_from_h5(byte_str):
        """Inverse of the pickling used when the file was written."""
        # SECURITY: pickle.load executes code embedded in the data; the HDF5
        # file must come from a trusted source.
        return pickle.load(io.BytesIO(byte_str))