Source code for ilastik.applets.base.appletSerializer

###############################################################################
#   ilastik: interactive learning and segmentation toolkit
#
#       Copyright (C) 2011-2014, the ilastik developers
#                                <team@ilastik.org>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# In addition, as a special exception, the copyright holders of
# ilastik give you permission to combine ilastik with applets,
# workflows and plugins which are not covered under the GNU
# General Public License.
#
# See the LICENSE file for details. License information is also available
# on the ilastik web site at:
#		   http://ilastik.org/license.html
###############################################################################
import logging
logger = logging.getLogger(__name__)

from abc import ABCMeta

from ilastik.config import cfg as ilastik_config
from ilastik.utility.simpleSignal import SimpleSignal
from ilastik.utility.maybe import maybe
import os
import re
import tempfile
import h5py
import numpy
import warnings
import cPickle as pickle

from lazyflow.roi import TinyVector, roiToSlice, sliceToRoi
from lazyflow.utility import timeLogged
from lazyflow.slot import OutputSlot

#######################
# Convenience methods #
#######################

def getOrCreateGroup(parentGroup, groupName):
    """Returns parentGroup[groupName], creating first it if
    necessary.

    """

    return parentGroup.require_group(groupName)

def deleteIfPresent(parentGroup, name):
    """Deletes parentGroup[name], if it exists."""
    # Check first. If we try to delete a non-existent key, 
    # hdf5 will complain on the console.
    if name in parentGroup:
        del parentGroup[name]

def slicingToString(slicing):
    """Convert the given slicing into a string of the form
    '[0:1,2:3,4:5]'

    """
    strSlicing = '['
    for s in slicing:
        strSlicing += str(s.start)
        strSlicing += ':'
        strSlicing += str(s.stop)
        strSlicing += ','

    strSlicing = strSlicing[:-1] # Drop the last comma
    strSlicing += ']'
    return strSlicing

def stringToSlicing(strSlicing):
    """Parse a string of the form '[0:1,2:3,4:5]' into a slicing (i.e.
    list of slices)

    """
    slicing = []
    strSlicing = strSlicing[1:-1] # Drop brackets
    sliceStrings = strSlicing.split(',')
    for s in sliceStrings:
        ends = s.split(':')
        start = int(ends[0])
        stop = int(ends[1])
        slicing.append(slice(start, stop))

    return slicing
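
# A quick round-trip sanity check for the two helpers above (illustrative,
# doctest-style sketch; not executed by this module):
#
#   >>> slicingToString([slice(0, 1), slice(2, 3), slice(4, 5)])
#   '[0:1,2:3,4:5]'
#   >>> stringToSlicing('[0:1,2:3,4:5]')
#   [slice(0, 1, None), slice(2, 3, None), slice(4, 5, None)]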


class SerialSlot(object):
    """Implements the logic for serializing a slot."""
    def __init__(self, slot, inslot=None, name=None, subname=None,
                 default=None, depends=None, selfdepends=True):
        """
        :param slot: where to get data to save
        :param inslot: where to put loaded data. If None, it is the
            same as 'slot'.
        :param name: name used for the group in the hdf5 file.
        :param subname: used for creating subgroups for multislots.
            Should be able to call subname.format(i), where i is an
            integer.
        :param default: DEPRECATED
        :param depends: a list of slots which must be ready before
            this slot can be serialized. If None, defaults to [].
        :param selfdepends: whether 'slot' should be added to 'depends'

        """
        if slot.level > 1:
            # FIXME: recursive serialization, to support arbitrary levels
            raise Exception('slots of levels > 1 not supported')
        self.slot = slot
        if inslot is None:
            inslot = slot
        self.inslot = inslot

        self.default = default
        self.depends = maybe(depends, [])
        if selfdepends:
            self.depends.append(slot)
        if name is None:
            name = slot.name
        self.name = name
        if subname is None:
            subname = '{:04d}'
        self.subname = subname

        self._dirty = False
        self._bind()
        self.ignoreDirty = False
    @property
    def dirty(self):
        return self._dirty

    @dirty.setter
    def dirty(self, isDirty):
        if not isDirty or (isDirty and not self.ignoreDirty):
            self._dirty = isDirty

    def setDirty(self, *args, **kwargs):
        self.dirty = True

    def _bind(self, slot=None):
        """Setup so that when slot is dirty, set appropriate dirty
        flag.
        """
        slot = maybe(slot, self.slot)

        def doMulti(slot, index, size):
            slot[index].notifyDirty(self.setDirty)
            slot[index].notifyValueChanged(self.setDirty)

        if slot.level == 0:
            slot.notifyDirty(self.setDirty)
            slot.notifyValueChanged(self.setDirty)
        else:
            slot.notifyInserted(doMulti)
            slot.notifyRemoved(self.setDirty)
    def shouldSerialize(self, group):
        """Whether to serialize or not."""
        result = self.dirty
        result |= self.name not in group.keys()
        for s in self.depends:
            result &= s.ready()
        return result
    def serialize(self, group):
        """Performs tasks common to all serializations, like changing
        dirty status.

        Do not override (unless for some reason this function does not
        do the right thing in your case). Instead override
        _serialize().

        :param group: The parent group in which to create this slot's
            group.
        :type group: h5py.Group

        """
        if not self.shouldSerialize(group):
            return
        deleteIfPresent(group, self.name)
        if self.slot.ready():
            self._serialize(group, self.name, self.slot)
        self.dirty = False
    @staticmethod
    def _saveValue(group, name, value):
        """Separate so that subclasses can override, if necessary.

        For instance, SerialListSlot needs to save an extra attribute
        if the value is an empty list.

        """
        group.create_dataset(name, data=value)
    def _serialize(self, group, name, slot):
        """
        :param group: The parent group.
        :type group: h5py.Group
        :param name: The name of the data or group
        :type name: string
        :param slot: the slot to serialize
        :type slot: lazyflow.slot.Slot

        """
        if slot.level == 0:
            try:
                self._saveValue(group, name, slot.value)
            except:
                self._saveValue(group, name, slot(()).wait())
        else:
            subgroup = group.create_group(name)
            for i, subslot in enumerate(slot):
                subname = self.subname.format(i)
                self._serialize(subgroup, subname, slot[i])
    def deserialize(self, group):
        """Performs tasks common to all deserializations.

        Do not override (unless for some reason this function does not
        do the right thing in your case). Instead override
        _deserialize().

        :param group: The parent group in which to create this slot's
            group.
        :type group: h5py.Group

        """
        if self.name not in group:
            return
        self._deserialize(group[self.name], self.inslot)
        self.dirty = False
    @staticmethod
    def _getValue(subgroup, slot):
        val = subgroup[()]
        slot.setValue(val)
    def _deserialize(self, subgroup, slot):
        """
        :param subgroup: *not* the parent group. This slot's group.
        :type subgroup: h5py.Group
        """
        if slot.level == 0:
            self._getValue(subgroup, slot)
        else:
            # Pair stored indexes with their keys,
            # e.g. [(0,'0'), (2, '2'), (3, '3')]
            # Note that in some cases an index might be intentionally skipped.
            indexes_to_keys = { int(k) : k for k in subgroup.keys() }

            # Ensure the slot is at least big enough to deserialize into.
            if not indexes_to_keys:
                max_index = 0
            else:
                max_index = max(indexes_to_keys.keys())
            if len(slot) < max_index+1:
                slot.resize(max_index+1)

            # Now retrieve the data
            for i, subslot in enumerate(slot):
                if i in indexes_to_keys:
                    key = indexes_to_keys[i]
                    # Sadly, we can't use the following assertion because it would break
                    # backwards compatibility with a bug we used to have in the key names.
                    #assert key == self.subname.format(i)
                    self._deserialize(subgroup[key], subslot)
                else:
                    # Since there was no data for this subslot in the project file,
                    # we disconnect the subslot.
                    subslot.disconnect()
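
# A minimal usage sketch for SerialSlot (hypothetical operator and group
# names; in practice an AppletSerializer constructs these and calls
# serialize()/deserialize() against the applet's top group):
#
#   >>> ss = SerialSlot(op.SomeValueSlot, name='SomeValue')
#   >>> ss.serialize(project_file['SomeApplet'])    # writes 'SomeValue' into the group
#   >>> ss.deserialize(project_file['SomeApplet'])  # pushes the value back via setValue()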


#######################################################
# some serial slots that are used in multiple applets #
#######################################################

class SerialListSlot(SerialSlot):
    """As the name implies: used for serializing a list.

    The only differences from the base class are:

    - if deserializing fails, sets the slot value to [].

    - if it succeeds, applies a transform to every element of the list
      (for instance, to convert it to the proper type).

    """
    def __init__(self, slot, inslot=None, name=None, subname=None,
                 default=None, depends=None, selfdepends=True,
                 transform=None, store_transform=None, iterable=list):
        """
        :param transform: function applied to members on deserialization.

        """
        # TODO: implement for multislots
        if slot.level > 0:
            raise NotImplementedError()

        super(SerialListSlot, self).__init__(
            slot, inslot, name, subname, default, depends, selfdepends
        )
        if transform is None:
            transform = lambda x: x
        self.transform = transform

        self._iterable = iterable
        self._store_transform = store_transform
        if store_transform is None:
            self._store_transform = lambda x: x
    def _saveValue(self, group, name, value):
        isempty = (len(value) == 0)
        if isempty:
            value = numpy.empty((1,))
        sg = group.create_dataset(name, data=map(self._store_transform, value))
        sg.attrs['isEmpty'] = isempty

    @timeLogged(logger, logging.DEBUG)
    def deserialize(self, group):
        logger.debug("Deserializing ListSlot: {}".format(self.name))
        try:
            subgroup = group[self.name]
        except KeyError:
            if logger.isEnabledFor(logging.DEBUG):
                # Only show this warning when debugging serialization
                warnings.warn("Deserialization: Could not locate value for slot '{}'. Skipping.".format(self.name))
            return

        if 'isEmpty' in subgroup.attrs and subgroup.attrs['isEmpty']:
            self.inslot.setValue(self._iterable([]))
        else:
            if len(subgroup.shape) == 0 or subgroup.shape[0] == 0:
                # How can this happen, anyway...?
                return
            else:
                self.inslot.setValue(self._iterable(map(self.transform, subgroup[()])))
        self.dirty = False
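
# Example (hypothetical slot): persist a list of label names, converting each
# stored element back to str on load:
#
#   >>> sls = SerialListSlot(op.LabelNames, transform=str)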

class SerialBlockSlot(SerialSlot):
    """A slot which only saves nonzero blocks."""
    def __init__(self, slot, inslot, blockslot, name=None, subname=None,
                 default=None, depends=None, selfdepends=True,
                 shrink_to_bb=False, compression_level=0):
        """
        :param blockslot: provides non-zero blocks.
        :param shrink_to_bb: If true, reduce each block of data from
            the slot to its nonzero bounding box before saving it.

        """
        assert isinstance(slot, OutputSlot), \
            "slot is of wrong type: '{}' is not an OutputSlot".format(slot.name)
        super(SerialBlockSlot, self).__init__(
            slot, inslot, name, subname, default, depends, selfdepends
        )
        self.blockslot = blockslot
        self._bind(slot)
        self._shrink_to_bb = shrink_to_bb
        self.compression_level = compression_level
    def shouldSerialize(self, group):
        """
        Must be overloaded because SerialBlockSlot does not serialize itself
        in the simple way that other SerialSlots do, as a consequence of the
        nesting of groups required. Follows the same logic as _serialize and
        checks whether each relevant subgroup has been created; if any are
        missing, or their data is missing, this slot should be serialized.
        Otherwise, if everything is intact, it doesn't suggest serialization
        unless the state has changed.
        """
        logger.debug("Checking whether to serialize BlockSlot: {}".format(self.name))

        if self.dirty:
            logger.debug("BlockSlot \"" + self.name + "\" appears to be dirty. Should serialize.")
            return True

        # SerialSlot interchanges self.name and name, which frequently are the
        # same thing. It is not clear whether using self.name would be
        # acceptable here, whether name should be an input to shouldSerialize,
        # or whether there should be a _shouldSerialize method that takes the
        # name.
        if self.name not in group:
            logger.debug("Missing \"" + self.name + "\" in group \"" + repr(group) + "\" belonging to BlockSlot \"" + self.name + "\". Should serialize.")
            return True
        else:
            logger.debug("Found \"" + self.name + "\" in group \"" + repr(group) + "\" belonging to BlockSlot \"" + self.name + "\".")

        # Just because the group was serialized doesn't mean that the relevant data was.
        mygroup = group[self.name]
        num = len(self.blockslot)
        for index in range(num):
            subname = self.subname.format(index)

            # Check to see if each subname has been created as a group
            if subname not in mygroup:
                logger.debug("Missing \"" + subname + "\" from \"" + repr(mygroup) + "\" belonging to BlockSlot \"" + self.name + "\". Should serialize.")
                return True
            else:
                logger.debug("Found \"" + subname + "\" from \"" + repr(mygroup) + "\" belonging to BlockSlot \"" + self.name + "\".")

            subgroup = mygroup[subname]
            nonZeroBlocks = self.blockslot[index].value
            for blockIndex in xrange(len(nonZeroBlocks)):
                blockName = 'block{:04d}'.format(blockIndex)
                if blockName not in subgroup:
                    logger.debug("Missing \"" + blockName + "\" from \"" + repr(subgroup) + "\". Should serialize.")
                    return True
                else:
                    logger.debug("Found \"" + blockName + "\" from \"" + repr(subgroup) + "\" belonging to BlockSlot \"" + self.name + "\".")

        logger.debug("Everything belonging to BlockSlot \"" + self.name + "\" appears to be in order. Should not serialize.")
        return False

    @timeLogged(logger, logging.DEBUG)
    def _serialize(self, group, name, slot):
        logger.debug("Serializing BlockSlot: {}".format(self.name))
        mygroup = group.create_group(name)
        num = len(self.blockslot)
        for index in range(num):
            subname = self.subname.format(index)
            subgroup = mygroup.create_group(subname)
            nonZeroBlocks = self.blockslot[index].value
            for blockIndex, slicing in enumerate(nonZeroBlocks):
                if not isinstance(slicing[0], slice):
                    slicing = roiToSlice(*slicing)
                block = self.slot[index][slicing].wait()
                blockName = 'block{:04d}'.format(blockIndex)

                if self._shrink_to_bb:
                    nonzero_coords = numpy.nonzero(block)
                    if len(nonzero_coords[0]) > 0:
                        block_start = sliceToRoi(slicing, (0,)*len(slicing))[0]
                        block_bounding_box_start = numpy.array(map(numpy.min, nonzero_coords))
                        block_bounding_box_stop = 1 + numpy.array(map(numpy.max, nonzero_coords))
                        block_slicing = roiToSlice(block_bounding_box_start, block_bounding_box_stop)
                        bounding_box_roi = numpy.array([block_bounding_box_start, block_bounding_box_stop])
                        bounding_box_roi += block_start

                        # Overwrite the vars that are written to the file
                        slicing = roiToSlice(*bounding_box_roi)
                        block = block[block_slicing]

                # If we have a masked array, convert it to a structured array so that h5py can handle it.
                if slot[index].meta.has_mask:
                    mygroup.attrs["meta.has_mask"] = True

                    block_group = subgroup.create_group(blockName)
                    if self.compression_level:
                        block_group.create_dataset("data", data=block.data,
                                                   compression='gzip',
                                                   compression_opts=self.compression_level)
                    else:
                        block_group.create_dataset("data", data=block.data)
                    block_group.create_dataset("mask", data=block.mask,
                                               compression="gzip", compression_opts=2)
                    block_group.create_dataset("fill_value", data=block.fill_value)
                    block_group.attrs['blockSlice'] = slicingToString(slicing)
                else:
                    subgroup.create_dataset(blockName, data=block)
                    subgroup[blockName].attrs['blockSlice'] = slicingToString(slicing)

    @timeLogged(logger, logging.DEBUG)
    def _deserialize(self, mygroup, slot):
        logger.debug("Deserializing BlockSlot: {}".format(self.name))
        num = len(mygroup)
        if len(self.inslot) < num:
            self.inslot.resize(num)

        # Annoyingly, some applets store their groups with names like
        # img0, img1, ..., img9, img10, img11, which means that sorted() needs
        # a special key to avoid sorting img10 before img2.
        # We have to find the index and sort according to its numerical value.
        index_capture = re.compile(r'[^0-9]*(\d*).*')
        def extract_index(s):
            return int(index_capture.match(s).groups()[0])
        for index, t in enumerate(sorted(mygroup.items(), key=lambda (k, v): extract_index(k))):
            groupName, labelGroup = t
            for blockData in labelGroup.values():
                slicing = stringToSlicing(blockData.attrs['blockSlice'])

                # If it is supposed to be a masked array,
                # deserialize the pieces and rebuild the masked array.
                assert bool(slot[index].meta.has_mask) == bool(mygroup.attrs.get("meta.has_mask", False)), \
                    "The slot and stored data have different values for" + \
                    " `has_mask`. They are" + \
                    " `bool(slot[index].meta.has_mask)`=" + \
                    repr(bool(slot[index].meta.has_mask)) + " and" + \
                    " `mygroup.attrs.get(\"meta.has_mask\", False)`=" + \
                    repr(mygroup.attrs.get("meta.has_mask", False)) + \
                    ". Please fix this to proceed with deserialization."
                if slot[index].meta.has_mask:
                    blockArray = numpy.ma.masked_array(
                        blockData["data"][()],
                        mask=blockData["mask"][()],
                        fill_value=blockData["fill_value"][()],
                        shrink=False
                    )
                else:
                    blockArray = blockData[...]

                self.inslot[index][slicing] = blockArray

class SerialHdf5BlockSlot(SerialBlockSlot):

    def _serialize(self, group, name, slot):
        mygroup = group.create_group(name)
        num = len(self.blockslot)
        for index in range(num):
            subname = self.subname.format(index)
            subgroup = mygroup.create_group(subname)
            cleanBlockRois = self.blockslot[index].value
            for roi in cleanBlockRois:
                # The protocol for hdf5 slots is that they create appropriately
                # named datasets within the subgroup that we provide via writeInto()
                req = self.slot[index](*roi)
                req.writeInto(subgroup)
                req.wait()

    def _deserialize(self, mygroup, slot):
        num = len(mygroup)
        if len(self.inslot) < num:
            self.inslot.resize(num)

        # Annoyingly, some applets store their groups with names like
        # img0, img1, ..., img9, img10, img11, which means that sorted() needs
        # a special key to avoid sorting img10 before img2.
        # We have to find the index and sort according to its numerical value.
        index_capture = re.compile(r'[^0-9]*(\d*).*')
        def extract_index(s):
            return int(index_capture.match(s).groups()[0])
        for index, t in enumerate(sorted(mygroup.items(), key=lambda (k, v): extract_index(k))):
            groupName, labelGroup = t
            assert extract_index(groupName) == index, \
                "subgroup extraction order should be numerical order!"
            for blockRoiString, blockDataset in labelGroup.items():
                blockRoi = eval(blockRoiString)
                roiShape = TinyVector(blockRoi[1]) - TinyVector(blockRoi[0])
                assert tuple(roiShape) == blockDataset.shape

                self.inslot[index][roiToSlice(*blockRoi)] = blockDataset

class SerialClassifierSlot(SerialSlot):
    """For saving a classifier. Here we assume the classifier is
    stored in the given cache operator."""
    def __init__(self, slot, cache, inslot=None, name=None,
                 default=None, depends=None, selfdepends=True):
        super(SerialClassifierSlot, self).__init__(
            slot, inslot, name, None, default, depends, selfdepends
        )
        self.cache = cache
        if self.name is None:
            self.name = slot.name

        # We want to bind to the INPUT, not Output:
        # - if the input becomes dirty, we want to make sure the cache is deleted
        # - if the input becomes dirty and then the cache is reloaded, we'll save the classifier.
        self._bind(cache.Input)
    def _serialize(self, group, name, slot):
        # Is the cache up-to-date?
        # If not, we'll just return. (Don't recompute the classifier just to save it.)
        if self.cache._dirty:
            return

        classifier = self.cache.Output.value

        # Classifier can be None if there isn't any training data yet.
        if classifier is None:
            return

        classifier_group = group.create_group(name)
        classifier.serialize_hdf5(classifier_group)
    def deserialize(self, group):
        """
        Have to override this to ensure that dirty is always set False.
        """
        super(SerialClassifierSlot, self).deserialize(group)
        self.dirty = False
    def _deserialize(self, classifierGroup, slot):
        try:
            classifier_type = pickle.loads(classifierGroup['pickled_type'][()])
        except KeyError:
            # For compatibility with old project files, choose the default classifier.
            from lazyflow.classifiers import ParallelVigraRfLazyflowClassifier
            classifier_type = ParallelVigraRfLazyflowClassifier

        try:
            classifier = classifier_type.deserialize_hdf5(classifierGroup)
        except:
            warnings.warn("Wasn't able to deserialize the saved classifier. "
                          "It will need to be retrained.")
            return

        # Now force the classifier into our classifier cache. The
        # downstream operators (e.g. the prediction operator) can
        # use the classifier without inducing it to be re-trained.
        # (This assumes that the classifier we are loading is
        # consistent with the images and labels that we just
        # loaded. As soon as training input changes, it will be
        # retrained.)
        self.cache.forceValue(classifier)

class SerialPickledValueSlot(SerialSlot):
    """
    For storing value slots whose data is a python object
    (not an array or a simple number).
    """
    def __init__(self, slot):
        super(SerialPickledValueSlot, self).__init__(slot)

    @staticmethod
    def _saveValue(group, name, value):
        group.create_dataset(name, data=pickle.dumps(value))

    @staticmethod
    def _getValue(subgroup, slot):
        val = subgroup[()]
        slot.setValue(pickle.loads(val))


class SerialCountingSlot(SerialSlot):
    """For saving a random forest classifier."""
    def __init__(self, slot, cache, inslot=None, name=None,
                 default=None, depends=None, selfdepends=True):
        super(SerialCountingSlot, self).__init__(
            slot, inslot, name, "wrapper{:04d}", default, depends, selfdepends
        )
        self.cache = cache
        if self.name is None:
            self.name = slot.name
        if self.subname is None:
            self.subname = "wrapper{:04d}"
        # We want to bind to the INPUT, not Output:
        # - if the input becomes dirty, we want to make sure the cache is deleted
        # - if the input becomes dirty and then the cache is reloaded, we'll save the classifier.
        self._bind(cache.Input)

    def _serialize(self, group, name, slot):
        if self.cache._dirty:
            return

        classifier_forests = self.cache.Output.value

        # Classifier can be None if there isn't any training data yet.
        if classifier_forests is None:
            return
        for forest in classifier_forests:
            if forest is None:
                return

        # Due to non-shared hdf5 dlls, vigra can't write directly to
        # our open hdf5 group. Instead, we'll use vigra to write the
        # classifier to a temporary file.
        tmpDir = tempfile.mkdtemp()
        cachePath = os.path.join(tmpDir, 'tmp_classifier_cache.h5').replace('\\', '/')
        for i, forest in enumerate(classifier_forests):
            targetname = '{0}/{1}'.format(name, self.subname.format(i))
            forest.writeHDF5(cachePath, targetname)

        # Open the temp file and copy to our project group
        with h5py.File(cachePath, 'r') as cacheFile:
            group.copy(cacheFile[name], name)

        os.remove(cachePath)
        os.rmdir(tmpDir)

    def deserialize(self, group):
        """
        Have to override this to ensure that dirty is always set False.
        """
        super(SerialCountingSlot, self).deserialize(group)
        self.dirty = False

    def _deserialize(self, classifierGroup, slot):
        # Due to non-shared hdf5 dlls, vigra can't read directly
        # from our open hdf5 group. Instead, we'll copy the
        # classifier data to a temporary file and give it to vigra.
        tmpDir = tempfile.mkdtemp()
        cachePath = os.path.join(tmpDir, 'tmp_classifier_cache.h5').replace('\\', '/')
        with h5py.File(cachePath, 'w') as cacheFile:
            cacheFile.copy(classifierGroup, self.name)

        try:
            forests = []
            for name, forestGroup in sorted(classifierGroup.items()):
                targetname = '{0}/{1}'.format(self.name, name)
                #forests.append(vigra.learning.RandomForest(cachePath, targetname))
                from ilastik.applets.counting.countingsvr import SVR
                forests.append(SVR.load(cachePath, targetname))
        except:
            warnings.warn("Wasn't able to deserialize the saved classifier. "
                          "It will need to be retrained.")
            return
        finally:
            os.remove(cachePath)
            os.rmdir(tmpDir)

        # Now force the classifier into our classifier cache. The
        # downstream operators (e.g. the prediction operator) can
        # use the classifier without inducing it to be re-trained.
        # (This assumes that the classifier we are loading is
        # consistent with the images and labels that we just
        # loaded. As soon as training input changes, it will be
        # retrained.)
        self.cache.forceValue(numpy.array(forests))


class SerialDictSlot(SerialSlot):
    """For saving a dictionary."""
    def __init__(self, slot, inslot=None, name=None, subname=None,
                 default=None, depends=None, selfdepends=True,
                 transform=None):
        """
        :param transform: a function called on each key before
            inserting it into the dictionary.

        """
        super(SerialDictSlot, self).__init__(
            slot, inslot, name, subname, default, depends, selfdepends
        )
        if transform is None:
            transform = lambda x: x
        self.transform = transform

    def _saveValue(self, group, name, value):
        sg = group.create_group(name)
        for key, v in value.iteritems():
            if isinstance(v, dict):
                self._saveValue(sg, key, v)
            else:
                sg.create_dataset(str(key), data=v)

    def _getValueHelper(self, subgroup):
        result = {}
        for key in subgroup.keys():
            if isinstance(subgroup[key], h5py.Group):
                value = self._getValueHelper(subgroup[key])
            else:
                value = subgroup[key][()]
            result[self.transform(key)] = value
        return result

    def _getValue(self, subgroup, slot):
        result = self._getValueHelper(subgroup)
        try:
            slot.setValue(result)
        except AssertionError as e:
            warnings.warn('setValue() failed. message: {}'.format(e.message))


class SerialClassifierFactorySlot(SerialSlot):
    def __init__(self, slot, name=None):
        super(SerialClassifierFactorySlot, self).__init__(slot, name=name)
        self._failed_to_deserialize = False
        assert slot.ready(), \
            "ClassifierFactory slots must be given a default value " \
            "(in case the classifier can't be deserialized in a future version of ilastik)."

    def _saveValue(self, group, name, value):
        pickled = pickle.dumps(value)
        group.create_dataset(name, data=pickled)
        self._failed_to_deserialize = False

    def shouldSerialize(self, group):
        if self._failed_to_deserialize:
            return True
        else:
            return super(SerialClassifierFactorySlot, self).shouldSerialize(group)

    def _getValue(self, dset, slot):
        pickled = dset[()]
        try:
            # Attempt to unpickle
            value = pickle.loads(pickled)

            # Verify that the VERSION of the classifier factory in the currently
            # executing code has not changed since this classifier was stored.
            assert 'VERSION' in value.__dict__ and value.VERSION == type(value).VERSION
        except:
            self._failed_to_deserialize = True
            warnings.warn("This project file uses an old or unsupported classifier storage format. "
                          "The classifier will be stored in the new format when you save your project.")
        else:
            slot.setValue(value)


class SerialPickleableSlot(SerialSlot):
    def __init__(self, slot, version, default, name=None):
        super(SerialPickleableSlot, self).__init__(slot, name=name)
        self._failed_to_deserialize = False
        self._version = version
        self._default = default

    def _saveValue(self, group, name, value):
        pickled = pickle.dumps(value)
        dset = group.create_dataset(name, data=pickled)
        dset.attrs['version'] = self._version
        self._failed_to_deserialize = False

    def shouldSerialize(self, group):
        if self._failed_to_deserialize:
            return True
        else:
            return super(SerialPickleableSlot, self).shouldSerialize(group)

    def _getValue(self, dset, slot):
        try:
            # First check that the stored version matches the expected version.
            loaded_version = dset.attrs['version']
            assert loaded_version == self._version

            # Attempt to unpickle
            pickled = dset[()]
            value = pickle.loads(pickled)
        except:
            self._failed_to_deserialize = True
            warnings.warn("This project file uses an old or unsupported storage format. "
                          "The next time you save the project, it will be stored in the new format.")
            slot.setValue(self._default)
        else:
            slot.setValue(value)


####################################
# the base applet serializer class #
####################################

class AppletSerializer(object):
    """
    Base class for all AppletSerializers.
    """
    # Force subclasses to override abstract methods and properties
    __metaclass__ = ABCMeta

    base_initialized = False

    # override if necessary
    version = "0.1"

    class IncompatibleProjectVersionError(Exception):
        pass

    #########################
    # Semi-abstract methods #
    #########################
    def _serializeToHdf5(self, topGroup, hdf5File, projectFilePath):
        """Child classes should override this function, if
        necessary.
        """
        pass
    def _deserializeFromHdf5(self, topGroup, groupVersion, hdf5File,
                             projectFilePath, headless=False):
        """Child classes should override this function, if
        necessary.
        """
        pass
    #############################
    # Base class implementation #
    #############################

    def __init__(self, topGroupName, slots=None, operator=None):
        """Constructor. Subclasses must call this method in their own
        __init__ functions. If they fail to do so, the shell raises an
        exception.

        Parameters:
        :param topGroupName: name of this applet's data group in the
            file. Defaults to the name of the operator.
        :param slots: a list of SerialSlots

        """
        self.progressSignal = SimpleSignal()  # Signature: emit(percentComplete)
        self.base_initialized = True
        self.topGroupName = topGroupName
        self.serialSlots = maybe(slots, [])
        self.operator = operator
        # Should _deserializeFromHdf5 be called with the headless argument?
        self.caresOfHeadless = False
        self._ignoreDirty = False
    def isDirty(self):
        """Returns true if the current state of this item (in memory)
        does not match the state of the HDF5 group on disk.

        Subclasses only need override this method if ORing the flags
        is not enough.

        """
        return any(ss.dirty for ss in self.serialSlots)
    def shouldSerialize(self, hdf5File):
        """Whether to serialize or not."""
        if self.isDirty():
            return True

        # Need to check whether the slots should be serialized. First, verify
        # that self.topGroupName is not an empty string (as this seems to
        # happen sometimes).
        if self.topGroupName:
            topGroup = getOrCreateGroup(hdf5File, self.topGroupName)
            return any([ss.shouldSerialize(topGroup) for ss in self.serialSlots])

        return False
    @property
    def ignoreDirty(self):
        return self._ignoreDirty

    @ignoreDirty.setter
    def ignoreDirty(self, value):
        self._ignoreDirty = value
        for ss in self.serialSlots:
            ss.ignoreDirty = value
    def progressIncrement(self, group=None):
        """Get the percentage progress for each slot.

        :param group: If None, all slots are assumed to be processed.
            Otherwise, decides for each slot by calling
            slot.shouldSerialize(group).

        """
        if group is None:
            nslots = len(self.serialSlots)
        else:
            nslots = sum(ss.shouldSerialize(group) for ss in self.serialSlots)
        if nslots == 0:
            return 0
        return divmod(100, nslots)[0]
    def serializeToHdf5(self, hdf5File, projectFilePath):
        """Serialize the current applet state to the given hdf5 file.

        Subclasses should **not** override this method. Instead,
        subclasses override the 'private' version,
        *_serializeToHdf5*

        :param hdf5File: An h5py.File handle to the project file,
            which should already be open
        :param projectFilePath: The path to the given file handle.
            (Most serializers do not use this parameter.)

        """
        topGroup = getOrCreateGroup(hdf5File, self.topGroupName)

        progress = 0
        self.progressSignal.emit(progress)

        # Set the version
        key = 'StorageVersion'
        deleteIfPresent(topGroup, key)
        topGroup.create_dataset(key, data=self.version)

        try:
            inc = self.progressIncrement(topGroup)
            for ss in self.serialSlots:
                ss.serialize(topGroup)
                progress += inc
                self.progressSignal.emit(progress)

            # Call the subclass to do remaining work, if any
            self._serializeToHdf5(topGroup, hdf5File, projectFilePath)
        finally:
            self.progressSignal.emit(100)
    def deserializeFromHdf5(self, hdf5File, projectFilePath, headless=False):
        """Read the current applet state from the given hdf5File
        handle, which should already be open.

        Subclasses should **not** override this method. Instead,
        subclasses override the 'private' version,
        *_deserializeFromHdf5*

        :param hdf5File: An h5py.File handle to the project file,
            which should already be open
        :param projectFilePath: The path to the given file handle.
            (Most serializers do not use this parameter.)
        :param headless: Are we called in headless mode?
            (In headless mode, corrupted files cannot be fixed via the
            GUI.)

        """
        self.progressSignal.emit(0)

        # If the top group isn't there, call initWithoutTopGroup
        try:
            topGroup = hdf5File[self.topGroupName]
            groupVersion = topGroup['StorageVersion'][()]
        except KeyError:
            topGroup = None
            groupVersion = None

        try:
            if topGroup is not None:
                inc = self.progressIncrement()
                for ss in self.serialSlots:
                    ss.deserialize(topGroup)
                    self.progressSignal.emit(inc)

                # Call the subclass to do remaining work
                if self.caresOfHeadless:
                    self._deserializeFromHdf5(topGroup, groupVersion, hdf5File,
                                              projectFilePath, headless)
                else:
                    self._deserializeFromHdf5(topGroup, groupVersion, hdf5File,
                                              projectFilePath)
            else:
                self.initWithoutTopGroup(hdf5File, projectFilePath)
        finally:
            self.progressSignal.emit(100)
    def repairFile(self, path, filt=None):
        """Get new path to lost file."""
        from PyQt4.QtGui import QFileDialog, QMessageBox
        from volumina.utility import encode_from_qstring

        text = ("The file at {} could not be found any more. "
                "Do you want to search for it at another directory?").format(path)
        logger.info(text)
        c = QMessageBox.critical(None, "update external data", text,
                                 QMessageBox.Ok | QMessageBox.Cancel)

        if c == QMessageBox.Cancel:
            raise RuntimeError("Could not find external data: " + path)

        options = QFileDialog.Options()
        if ilastik_config.getboolean("ilastik", "debug"):
            options |= QFileDialog.DontUseNativeDialog
        fileName = QFileDialog.getOpenFileName(None, "repair files", path, filt,
                                               options=options)
        if fileName.isEmpty():
            raise RuntimeError("Could not find external data: " + path)
        else:
            return encode_from_qstring(fileName)
    ######################
    #  Optional methods  #
    ######################
    def initWithoutTopGroup(self, hdf5File, projectFilePath):
        """Optional override for subclasses. Called when there is no
        top group to deserialize.
        """
        pass
    def updateWorkingDirectory(self, newdir, olddir):
        """Optional override for subclasses. Called when the working
        directory is changed and relative paths have to be updated.
        Child classes should override this method if they store
        relative paths.
        """
        pass
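
# A minimal subclass sketch (hypothetical operator and slot names), showing
# the intended wiring: pass the applet's top group name and a list of
# SerialSlots to the base class, which then drives (de)serialization and
# progress reporting.
#
#   class ExampleSerializer(AppletSerializer):
#       def __init__(self, operator, projectFileGroupName):
#           slots = [SerialSlot(operator.SomeParameter),
#                    SerialListSlot(operator.LabelNames, transform=str)]
#           super(ExampleSerializer, self).__init__(projectFileGroupName,
#                                                   slots=slots, operator=operator)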