A Broadcast (broadcast variable) is a read-only variable whose value is cached on every node, instead of each task fetching its own copy. This reduces network overhead during computation.
Basic usage of a broadcast variable consists of creating it and reading it.
Create:

```
scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
```

Read:

```
scala> broadcastVar.value
res0: Array[Int] = Array(1, 2, 3)
```
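As a minimal usage sketch (assuming a running SparkContext `sc`; the lookup table and RDD here are made up for illustration), a broadcast variable is typically read inside a transformation, so each executor caches one copy of the data rather than shipping it with every task:

```scala
// Hypothetical example: enrich an RDD using a small lookup table.
// Without broadcast, `countryNames` would be serialized into every task's
// closure; with broadcast, each executor fetches and caches a single copy.
val countryNames = Map("CN" -> "China", "US" -> "United States")
val bcNames = sc.broadcast(countryNames)

val codes = sc.parallelize(Seq("CN", "US", "CN"))
val named = codes.map(code => bcNames.value.getOrElse(code, "unknown"))
named.collect() // Array("China", "United States", "China")
```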
BroadcastManager manages broadcast variables. Its instance is created in the create method of SparkEnv.scala:
```scala
private def create(
    conf: SparkConf,
    executorId: String,
    bindAddress: String,
    advertiseAddress: String,
    port: Option[Int],
    isLocal: Boolean,
    numUsableCores: Int,
    ioEncryptionKey: Option[Array[Byte]],
    listenerBus: LiveListenerBus = null,
    mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
  ...
  // Create the BroadcastManager
  val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
  // Create the MapOutputTracker
  val mapOutputTracker = if (isDriver) {
    new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
  } else {
    new MapOutputTrackerWorker(conf)
  }
  ...
}
```
The BroadcastManager constructor calls the initialize method:
```scala
private def initialize() {
  synchronized {
    if (!initialized) {
      // Instantiate the TorrentBroadcastFactory
      broadcastFactory = new TorrentBroadcastFactory
      // Call TorrentBroadcastFactory's initialize method
      broadcastFactory.initialize(isDriver, conf, securityManager)
      initialized = true
    }
  }
}
```
TorrentBroadcastFactory's initialize, however, does nothing:
```scala
private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
}
```
Creating a broadcast
A broadcast variable is created by the broadcast method of SparkContext.scala, which in turn calls BroadcastManager's newBroadcast method:
```scala
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
  assertNotStopped()
  require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
    "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  val callSite = getCallSite
  logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
  cleaner.foreach(_.registerBroadcastForCleanup(bc))
  bc
}
```
newBroadcast delegates to broadcastFactory's newBroadcast method, which in practice is TorrentBroadcastFactory's newBroadcast:
```scala
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
  broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
```
TorrentBroadcastFactory's newBroadcast creates a TorrentBroadcast object:
```scala
private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
  ...
  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }
  ...
}
```
The TorrentBroadcast constructor calls writeBlocks, which writes the broadcast value into the driver's BlockManager so that executors can later fetch it:
```scala
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {
  ...
  private val numBlocks: Int = writeBlocks(obj)

  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    val blockManager = SparkEnv.get.blockManager
    // Store a copy of the broadcast value on the driver so that tasks run on
    // the driver do not create their own copy of it.
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    // Serialize the object into byte blocks
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      // Save the byte blocks to the BlockManager
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    blocks.length
  }
}
```
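The chunking step above can be sketched in plain Scala. This is a simplified stand-in for TorrentBroadcast.blockifyObject and calcChecksum, not the real implementation: real pieces default to 4 MB (spark.broadcast.blockSize), and Spark checksums each piece with Adler-32.

```scala
import java.util.zip.Adler32

// Simplified stand-in for TorrentBroadcast.blockifyObject: split a serialized
// payload into fixed-size pieces (4 bytes here only to keep the example small).
def blockify(payload: Array[Byte], blockSize: Int): Array[Array[Byte]] =
  payload.grouped(blockSize).toArray

// Simplified stand-in for calcChecksum: an Adler-32 checksum per piece,
// verified again when a piece is fetched from a remote node.
def checksum(block: Array[Byte]): Int = {
  val adler = new Adler32()
  adler.update(block, 0, block.length)
  adler.getValue.toInt
}

val payload = (0 until 10).map(_.toByte).toArray
val pieces  = blockify(payload, 4)   // 3 pieces: 4 + 4 + 2 bytes
val sums    = pieces.map(checksum)   // one checksum per piece

assert(pieces.length == 3)
assert(pieces.flatten.sameElements(payload)) // "unblockify" = concatenate pieces
```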
Reading a broadcast
Calling a broadcast variable's value method invokes TorrentBroadcast's getValue, which ultimately calls readBroadcastBlock. readBroadcastBlock proceeds as follows:
```scala
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  @transient private lazy val _value: T = readBroadcastBlock()

  override protected def getValue() = {
    _value
  }

  private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
      // If the value is already cached, return it from the cache
      Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
        setConf(SparkEnv.get.conf)
        val blockManager = SparkEnv.get.blockManager
        // Try the local BlockManager
        blockManager.getLocalValues(broadcastId) match {
          case Some(blockResult) =>
            if (blockResult.data.hasNext) {
              val x = blockResult.data.next().asInstanceOf[T]
              releaseLock(broadcastId)
              if (x != null) {
                // Write the value into the local cache
                broadcastCache.put(broadcastId, x)
              }
              x
            } else {
              throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
            }
          case None =>
            logInfo("Started reading broadcast variable " + id)
            val startTimeMs = System.currentTimeMillis()
            // Fetch the data from remote nodes
            val blocks = readBlocks()
            logInfo("Reading broadcast variable " + id + " took " + Utils.getUsedTimeMs(startTimeMs))
            try {
              val obj = TorrentBroadcast.unBlockifyObject[T](
                blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
              val storageLevel = StorageLevel.MEMORY_AND_DISK
              if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
                throw new SparkException(s"Failed to store $broadcastId in BlockManager")
              }
              if (obj != null) {
                // Write the value into the local cache
                broadcastCache.put(broadcastId, obj)
              }
              obj
            } finally {
              blocks.foreach(_.dispose())
            }
        }
      }
    }
  }
}
```
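The lookup order above (process-wide cache, then local BlockManager, then remote fetch with a local write-back) can be sketched as a generic three-tier read. The map-backed stores and fetchRemote below are hypothetical stand-ins for Spark's components, not real APIs:

```scala
import scala.collection.mutable

// Stand-ins: process-wide cache (broadcastManager.cachedValues) and
// the local BlockManager, modeled as plain mutable maps.
val cache      = mutable.Map.empty[String, String]
val localStore = mutable.Map.empty[String, String]

// Stands in for readBlocks() + unBlockifyObject: fetch and reassemble remotely.
def fetchRemote(id: String): String = s"value-of-$id"

def readBroadcast(id: String): String =
  cache.getOrElseUpdate(id,          // tier 1: process-wide cache
    localStore.getOrElse(id, {       // tier 2: local BlockManager
      val v = fetchRemote(id)        // tier 3: remote fetch
      localStore(id) = v             // putSingle: keep a local copy for later tasks
      v
    }))

assert(readBroadcast("broadcast_0") == "value-of-broadcast_0")
assert(cache.contains("broadcast_0")) // the second read would hit the cache
```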
readBlocks fetches the pieces in a random order. Since each executor shuffles the piece ids independently, and re-publishes every piece it fetches to its own BlockManager (tellMaster = true), the load spreads across nodes instead of every executor pulling all the data from the driver, avoiding a hotspot:
```scala
private def readBlocks(): Array[BlockData] = {
  val blocks = new Array[BlockData](numBlocks)
  val bm = SparkEnv.get.blockManager
  // Fetch the pieces in a random order
  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    logDebug(s"Reading piece $pieceId of $broadcastId")
    bm.getLocalBytes(pieceId) match {
      case Some(block) =>
        blocks(pid) = block
        releaseLock(pieceId)
      case None =>
        bm.getRemoteBytes(pieceId) match {
          case Some(b) =>
            if (checksumEnabled) {
              val sum = calcChecksum(b.chunks(0))
              if (sum != checksums(pid)) {
                throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                  s" $sum != ${checksums(pid)}")
              }
            }
            if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
              throw new SparkException(
                s"Failed to store $pieceId of $broadcastId in local BlockManager")
            }
            blocks(pid) = new ByteBufferBlockData(b, true)
          case None =>
            throw new SparkException(s"Failed to get $pieceId of $broadcastId")
        }
    }
  }
  blocks
}
```
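The effect of the shuffled fetch order can be sketched in plain Scala: each executor visits the same piece ids but in an independent random order, so concurrent readers of a fresh broadcast tend to start on different pieces (the names here are illustrative):

```scala
import scala.util.Random

val numBlocks = 5

// Two executors fetching the same broadcast each shuffle the piece ids
// independently, mirroring Random.shuffle(Seq.range(0, numBlocks)) above.
val fetchOrderA = Random.shuffle((0 until numBlocks).toList)
val fetchOrderB = Random.shuffle((0 until numBlocks).toList)

// Every piece is still fetched exactly once per executor; only the order
// differs, which is what decorrelates the load across nodes.
assert(fetchOrderA.sorted == (0 until numBlocks).toList)
assert(fetchOrderB.sorted == (0 until numBlocks).toList)
```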
Reposted from: http://arcmb.baihongyu.com/