分类
ASRT Python学习

为Keras包装一个线程安全的数据生成器

我们在使用Keras训练深度学习模型时,往往不能一次将数据全部加载进内存中,那样会导致内存不足的问题。包括Keras在内的深度学习框架提供了动态数据加载的模式,也就是说,需要使用到哪些数据时,才会加载哪些数据,而Keras需要我们自己定义一个数据生成器,并通过多线程的机制调用我们传入的数据生成器,克服硬盘的IO速度瓶颈,以实现数据的动态加载。

在需要某些数据时,才将这些数据读入,缓解了内存远远的压力,并使得训练大规模数据成为了现实。但是由此带来一个问题,我们都知道,在计算机体系结构中,速度最快的是运算器(比如CPU),其次是CPU中的各种寄存器,然后是高速缓存Cache,再其次是内存,速度最慢的是各种IO过程。很不幸,硬盘的读写就是速度最慢的IO过程,尤其是普通机械硬盘,与其“同病相怜”的还有外设的IO和网络传输的IO等。所以,Keras框架为了解决动态加载数据时IO速度的瓶颈,默认情况下会通过多线程来加载数据,模型在参数更新的同时,就会将接下来的若干批训练样本加载完毕。

当然一般情况下,我们随便写一个串行的数据生成器,在单例模式中,这没有任何问题。但是在多线程机制下,对同一个对象进行操作,有一定概率会出现“打架”的现象,随着时间窗口的密度升高时,发生冲突的概率越来越大。Python在检测到问题后,会自动抛出异常,提示我们该生成器正在运行。为了避免这种问题的发生,很多资料都提到将多线程或多进程操作改为单线程或单进程操作,比如对Keras的函数”fit_generator”设置参数” use_multiprocessing=False, worker=0”等,但这是不明智的,由于串行加载数据时大量地等待硬盘的IO,增加深度学习训练的时间开销,所以多线程是有必要的。多线程的线程安全详细原理和解决方案请查看我的上一篇博客文章:通过同步和加锁解决多线程的线程安全问题

通过上一篇博客我们可以知道,通过加锁,我们可以控制对某个资源的访问,以实现线程安全的运行过程。所以,在使用python的生成器时,我们也可以通过加锁的方式解决该问题,代码如下:

ThreadingSafetyIter.py

import threading
'''
    A generic iterator and generator that takes any iterator and wrap it to make it thread safe.
    This method was introducted by Anand Chitipothu in http://anandology.com/blog/using-iterators-and-generators/
    but was not compatible with python 3. This modified version is now compatible and works both in python 2.8 and 3.0 
'''
class threadsafe_iter:
    """Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()


    def __iter__(self):
        return self


    def __next__(self):
        with self.lock:
            return self.it.__next__()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe.
    """
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g


调用方法:

from ThreadingSafetyIter import threadsafe_generator
'''
    Usage Examples. Here's how to use @threadsafe_generator to make any generator thread safe:
'''
@threadsafe_generator
def count():
    i = 0
    while True:
        i += 1
        yield i

代码来源:

https://gist.github.com/platdrag/e755f3947552804c42633a99ffd325d4

版权声明
本博客的文章除特别说明外均为原创,本人版权所有。欢迎转载,转载请注明作者及来源链接,谢谢。
本文地址: https://blog.ailemon.net/2019/05/20/a-thread-safety-warped-generator-for-keras/
All articles are under Attribution-NonCommercial-ShareAlike 4.0

关注“AI柠檬博客”微信公众号,及时获取你最需要的干货。


“为Keras包装一个线程安全的数据生成器”上的2条回复

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

6 − 5 =

如果您是第一次在本站发布评论,内容将在博主审核后显示,请耐心等待