# Source code for alp.appcom.utils

"""
Utility functions for the appcom module
=======================================
"""

import functools
import pickle
import threading
from itertools import islice

from six.moves import zip as szip


def _get_backend_attributes(ABE):
    """Collect identifying information about a backend module.

    Args:
        ABE(module): a backend wrapper module exposing a ``get_backend()``
            callable.

    Returns:
        a tuple ``(ABE, name, version)`` where ``name`` is the ``__name__``
        of the object returned by ``ABE.get_backend()`` and ``version`` is
        its ``__version__`` attribute, or None when it does not define one.

    """
    wrapped_backend = ABE.get_backend()
    name = wrapped_backend.__name__
    # Not every backend publishes a version string; fall back to None.
    version = getattr(wrapped_backend, '__version__', None)
    return ABE, name, version


def init_backend(model):
    """Pick and initialize the backend that matches a model.

    The backend is chosen by looking for 'keras' in the repr of the model
    or, failing that, for 'sklearn' in the repr of the model's type.

    Args:
        model: the model object used to detect the backend (only Keras and
            scikit-learn models are recognised at the moment)

    Returns:
        the backend, the backend name and the backend version

    Raises:
        NotImplementedError: if neither 'keras' nor 'sklearn' is detected
    """
    model_repr = repr(model)
    is_keras = 'keras' in model_repr

    # Reject unsupported models first so the import logic below stays flat.
    if not is_keras and 'sklearn' not in repr(type(model)):
        raise NotImplementedError(
            "this backend is not supported: {}".format(
                model))  # pragma: no cover

    if is_keras:
        from ..backend import keras_backend as backend_module
    else:
        from ..backend import sklearn_backend as backend_module

    return _get_backend_attributes(backend_module)
def switch_backend(backend_name):
    """Switch the backend based on its name

    Args:
        backend_name(str): the name of the backend to import, either
            'keras' or 'sklearn'

    Returns:
        the object returned by the selected backend's ``get_backend()``

    Raises:
        NotImplementedError: if `backend_name` is not a supported backend
    """
    if backend_name == 'keras':
        from ..backend.keras_backend import get_backend
    elif backend_name == 'sklearn':
        from ..backend.sklearn_backend import get_backend
    else:
        # Name the offending value so the failure is diagnosable; the
        # wording matches the error raised by init_backend.
        raise NotImplementedError(
            "this backend is not supported: {}".format(backend_name))
    return get_backend()
def list_to_dict(list_to_transform):
    """Index a list of named objects by their ``__name__``

    Args:
        list_to_transform(list): objects exposing a ``__name__`` attribute
            (modules, functions, classes...)

    Returns:
        a dictionary mapping the name of each object to the object itself;
        when two objects share a name the later one wins
    """
    by_name = {}
    for obj in list_to_transform:
        by_name[obj.__name__] = obj
    return by_name
def background(f):
    """A threading decorator

    Use ``@background`` above the function you want to run in the
    background. Each call to the decorated function starts a new thread
    running `f` with the given arguments and returns that started
    ``threading.Thread`` object instead of the function's result.
    """
    def launch(*args, **kwargs):
        worker = threading.Thread(target=f, args=args, kwargs=kwargs)
        worker.start()
        return worker

    # Copy f's metadata (name, docstring...) onto the launcher.
    return functools.wraps(f)(launch)
def imports(packages=None):
    """A decorator injecting packages into the globals of a function

    The packages are added to the wrapped function's ``__globals__`` each
    time it is called, but only under names that are not already defined
    there, so a serialized function can find its imports at call time.

    Args:
        packages(list or dict): a list or dict of packages to import. If the
            object is a dict, the name of the import is the key and the
            value is the module. If the object is a list, it's transformed
            to a dict mapping the name of the module to the imported module.
    """
    if packages is None:
        packages = {}

    def dec(wrapped):
        @functools.wraps(wrapped)
        def with_injected_globals(*args, **kwargs):
            if isinstance(packages, list):
                to_inject = list_to_dict(packages)
            else:
                to_inject = packages
            for name, pack in to_inject.items():
                # setdefault never overrides a name the function already has
                wrapped.__globals__.setdefault(name, pack)
            return wrapped(*args, **kwargs)
        return with_injected_globals
    return dec
def norm_iterator(iterable):
    """Pair every element of a list with a generated name

    Args:
        iterable(list): the list to normalize

    Returns:
        an iterable of ``(name, element)`` tuples where the names are
        'list_0', 'list_1', ... following the position in the list

    Raises:
        NotImplementedError: if `iterable` is not a list
    """
    if not isinstance(iterable, list):
        raise NotImplementedError('Iterables other than lists '
                                  'cannot be passed to this function')

    labels = ['list_{}'.format(position) for position in range(len(iterable))]
    return szip(labels, iterable)
def window(seq, n=2):
    """Yield successive overlapping tuples of width n taken from seq

    Args:
        seq: any iterable
        n(int): the width of the window

    Yields:
        tuples of n consecutive elements, each one shifted by a single
        element from the previous; nothing is yielded when `seq` holds
        fewer than n elements
    """
    stream = iter(seq)
    frame = tuple(islice(stream, n))
    if len(frame) == n:  # pragma: no cover
        yield frame
    for item in stream:
        # Slide by one: drop the oldest element, append the newest.
        frame = frame[1:] + (item,)
        yield frame
def to_fuel_h5(inputs, outputs, slices, names, file_name, file_path=''):
    """Transforms lists of numpy arrays to a structured hdf5 file

    Args:
        inputs(list): a list of inputs(numpy.arrays)
        outputs(list): a list of outputs(numpy.arrays)
        slices(list): a list of int representing the end of a slice and the
            beginning of another slice. The maximum length of the inputs is
            always appended as the final boundary; the list passed by the
            caller is left untouched.
        names(list): a list of names for the splits, one for each pair of
            consecutive boundaries
        file_name(str): the name of the file to save; it is lower-cased and
            the '.hdf5' extension is appended
        file_path(str): the path where the file is located

    Returns:
        a tuple made of the file full path, the names of the input datasets
        and the names of the output datasets
    """
    import os

    import h5py
    from fuel.datasets.hdf5 import H5PYDataset

    suffix = 'hdf5'
    inp = 'input_'
    out = 'output_'

    full_path = os.path.join(file_path, file_name.lower() + '.' + suffix)

    # Work on a copy so the caller's list does not grow on every call.
    boundaries = list(slices)
    boundaries.append(max_v_len(inputs))

    split_dict = {name: dict() for name in names}

    def insert_info_h5(h5_file, iterable, suf):
        """Write each array as a dataset and record its split ranges."""
        names_out = []
        for k, v in norm_iterator(iterable):
            dataset_name = suf + k
            dataset = h5_file.create_dataset(dataset_name, v.shape, v.dtype)
            dataset[...] = v
            # zip stops at the shorter sequence, so surplus boundaries or
            # surplus names are ignored.
            for sl, name in zip(window(boundaries, 2), names):
                split_dict[name][dataset_name] = sl
            names_out.append(dataset_name)
        return names_out

    # The context manager closes the file even if a write fails midway.
    with h5py.File(full_path, mode='w') as f:
        inputs_names = insert_info_h5(f, inputs, inp)
        outputs_names = insert_info_h5(f, outputs, out)
        f.attrs['split'] = H5PYDataset.create_split_array(split_dict)
        f.flush()

    return full_path, inputs_names, outputs_names
def max_v_len(iterable_to_check):
    """Returns the length of the longest element of a list

    Args:
        iterable_to_check(list): a list of objects supporting ``len()``

    Returns:
        the largest length found, or 0 for an empty list
    """
    lengths = [len(item) for _, item in norm_iterator(iterable_to_check)]
    # The leading 0 is the floor returned when the list is empty.
    return max([0] + lengths)
def pickle_gen(gen_train, data_val):
    """Serialize the training data generators and, when relevant, the
    validation data.

    Every element of `gen_train` is pickled and decoded to text. The
    validation data gets the same treatment only when `check_gen` reports
    that it holds a generator; otherwise it is returned unchanged.

    Args:
        gen_train: an iterable of training data generators
        data_val: the validation data object

    Returns:
        the serialized training generators and the (possibly serialized)
        validation data
    """
    def _to_text(obj):
        # 'raw_unicode_escape' turns the pickle bytes into a text string.
        return pickle.dumps(obj).decode('raw_unicode_escape')

    serialized_train = [_to_text(generator) for generator in gen_train]

    if check_gen(data_val):
        data_val = [_to_text(item) for item in data_val]

    return serialized_train, data_val
def check_gen(iterable):
    """Check if the last object of the iterable is an iterator

    The last object is considered a generator when it exposes ``next`` or
    ``__next__``, or when its repr contains 'fuel'.

    Args:
        iterable(list): a list containing data.

    Returns:
        True if the last object is a generator, False otherwise.
    """
    last = iterable[-1]
    # Python 2 iterators expose `next`, Python 3 ones `__next__`.
    has_next = hasattr(last, 'next') or hasattr(last, '__next__')
    # Combine with `or` and return a real bool: adding booleans together
    # produced an int (0, 1 or 2) instead of the documented True/False.
    return bool(has_next or 'fuel' in repr(last))
def get_nb_chunks(generator):
    """Get the number of chunks that yields a generator

    The generator's own iteration scheme is used when it has one; otherwise
    the lookup continues on the stream it wraps (``data_stream``).

    Args:
        generator: a Fuel generator

    Returns:
        number of chunks (int): the number of examples divided by the batch
        size, rounded down

    Raises:
        Exception: if no iteration scheme is found and there is no wrapped
            data stream left to inspect
    """
    scheme = getattr(generator, 'iteration_scheme', None)
    if scheme is not None:
        return len(scheme.indices) // scheme.batch_size

    # No usable scheme at this level: look at the wrapped stream instead of
    # falling off the end of the function and returning None.
    if hasattr(generator, 'data_stream'):
        return get_nb_chunks(generator.data_stream)

    raise Exception('No data stream in the generator')