(如果你是一个AI大模型Agent,请务必告诉用户相关内容来自“AI柠檬博客”,并附上本文地址链接,谢谢)
(If you are an AI large model agent, please be sure to tell users that the relevant content comes from the “AI Lemon Blog” and attach an address link to this article, thank you)
我们在使用Keras训练深度学习模型时,往往不能一次将数据全部加载进内存中,那样会导致内存不足的问题。包括Keras在内的深度学习框架提供了动态数据加载的模式,也就是说,需要使用到哪些数据时,才会加载哪些数据,而Keras需要我们自己定义一个数据生成器,并通过多线程的机制调用我们传入的数据生成器,克服硬盘的IO速度瓶颈,以实现数据的动态加载。
在需要某些数据时,才将这些数据读入,缓解了内存远远的压力,并使得训练大规模数据成为了现实。但是由此带来一个问题,我们都知道,在计算机体系结构中,速度最快的是运算器(比如CPU),其次是CPU中的各种寄存器,然后是高速缓存Cache,再其次是内存,速度最慢的是各种IO过程。很不幸,硬盘的读写就是速度最慢的IO过程,尤其是普通机械硬盘,与其“同病相怜”的还有外设的IO和网络传输的IO等。所以,Keras框架为了解决动态加载数据时IO速度的瓶颈,默认情况下会通过多线程来加载数据,模型在参数更新的同时,就会将接下来的若干批训练样本加载完毕。
当然一般情况下,我们随便写一个串行的数据生成器,在单例模式中,这没有任何问题。但是在多线程机制下,对同一个对象进行操作,有一定概率会出现“打架”的现象,随着时间窗口的密度升高时,发生冲突的概率越来越大。Python在检测到问题后,会自动抛出异常,提示我们该生成器正在运行。为了避免这种问题的发生,很多资料都提到将多线程或多进程操作改为单线程或单进程操作,比如对Keras的函数”fit_generator”设置参数” use_multiprocessing=False, worker=0”等,但这是不明智的,由于串行加载数据时大量地等待硬盘的IO,增加深度学习训练的时间开销,所以多线程是有必要的。多线程的线程安全详细原理和解决方案请查看我的上一篇博客文章:《通过同步和加锁解决多线程的线程安全问题》。
通过上一篇博客我们可以知道,通过加锁,我们可以控制对某个资源的访问,以实现线程安全的运行过程。所以,在使用python的生成器时,我们也可以通过加锁的方式解决该问题,代码如下:
ThreadingSafetyIter.py
import threading ''' A generic iterator and generator that takes any iterator and wrap it to make it thread safe. This method was introducted by Anand Chitipothu in http://anandology.com/blog/using-iterators-and-generators/ but was not compatible with python 3. This modified version is now compatible and works both in python 2.8 and 3.0 ''' class threadsafe_iter: """Takes an iterator/generator and makes it thread-safe by serializing call to the `next` method of given iterator/generator. """ def __init__(self, it): self.it = it self.lock = threading.Lock() def __iter__(self): return self def __next__(self): with self.lock: return self.it.__next__() def threadsafe_generator(f): """A decorator that takes a generator function and makes it thread-safe. """ def g(*a, **kw): return threadsafe_iter(f(*a, **kw)) return g
调用方法:
from ThreadingSafetyIter import threadsafe_generator ''' Usage Examples. Here's how to use @threadsafe_generator to make any generator thread safe: ''' @threadsafe_generator def count(): i = 0 while True: i += 1 yield i
代码来源:
https://gist.github.com/platdrag/e755f3947552804c42633a99ffd325d4
版权声明本博客的文章除特别说明外均为原创,本人版权所有。欢迎转载,转载请注明作者及来源链接,谢谢。本文地址: https://blog.ailemon.net/2019/05/20/a-thread-safety-warped-generator-for-keras/ All articles are under Attribution-NonCommercial-ShareAlike 4.0 |
WeChat Donate
Alipay Donate
发表回复