Spark Broadcast Source Code Analysis
Published: 2019-05-10


1. Broadcast Overview

A broadcast variable (Broadcast) is a read-only variable whose value is cached once on every node, instead of a separate copy being shipped to each task. This reduces network overhead during a computation.

Basic use of a broadcast variable consists of creating it and then reading it.

Creating

scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)

Reading

scala> broadcastVar.value
res0: Array[Int] = Array(1, 2, 3)
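
In a job, a broadcast variable is typically read inside a transformation, so the value is fetched once per executor rather than serialized into every task closure. A minimal sketch (the lookup table and RDD below are made up for illustration, assuming an active SparkContext sc):

// Sketch: using a broadcast lookup table inside a map.
val countryNames = sc.broadcast(Map("CN" -> "China", "US" -> "United States"))
val codes = sc.parallelize(Seq("CN", "US", "CN"))
val names = codes.map(code => countryNames.value.getOrElse(code, "unknown"))
names.collect()  // Array(China, United States, China)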

2. BroadcastManager Initialization

BroadcastManager manages broadcast variables. Its instance is created in the create method of SparkEnv.scala.

private def create(
    conf: SparkConf,
    executorId: String,
    bindAddress: String,
    advertiseAddress: String,
    port: Option[Int],
    isLocal: Boolean,
    numUsableCores: Int,
    ioEncryptionKey: Option[Array[Byte]],
    listenerBus: LiveListenerBus = null,
    mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
  ...
  // Create the BroadcastManager
  val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
  // Create the MapOutputTracker
  val mapOutputTracker = if (isDriver) {
    new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
  } else {
    new MapOutputTrackerWorker(conf)
  }
  ...
}

The BroadcastManager constructor calls the initialize method:

private def initialize() {
  synchronized {
    if (!initialized) {
      // Instantiate the TorrentBroadcastFactory
      broadcastFactory = new TorrentBroadcastFactory
      // Call TorrentBroadcastFactory's initialize method
      broadcastFactory.initialize(isDriver, conf, securityManager)
      initialized = true
    }
  }
}

Note, however, that TorrentBroadcastFactory's initialize is actually a no-op:

private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
  }
}

3. Creating and Reading a Broadcast

Creating a broadcast

Broadcast creation is handled by the broadcast method of SparkContext.scala, which in turn calls BroadcastManager's newBroadcast method.

def broadcast[T: ClassTag](value: T): Broadcast[T] = {
  assertNotStopped()
  require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
    "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
  val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
  val callSite = getCallSite
  logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
  cleaner.foreach(_.registerBroadcastForCleanup(bc))
  bc
}
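
The require check above is why broadcasting an RDD directly fails. The intended pattern, as the error message says, is to collect first and then broadcast the result. A minimal sketch (smallRdd is made up for illustration and assumed to fit in driver memory):

// Sketch: collect a small RDD on the driver, then broadcast its contents.
val smallRdd = sc.parallelize(Seq((1, "a"), (2, "b")))
val lookup = sc.broadcast(smallRdd.collect().toMap)  // Map(1 -> "a", 2 -> "b")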

newBroadcast delegates to broadcastFactory's newBroadcast method; since the factory is a TorrentBroadcastFactory, it is TorrentBroadcastFactory's newBroadcast that actually runs.

def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
  broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
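
nextBroadcastId here is an AtomicLong, which is what produces the sequential ids such as the Broadcast(0) seen in the REPL example earlier. A minimal sketch of that id scheme:

import java.util.concurrent.atomic.AtomicLong

// Sketch: getAndIncrement hands out 0, 1, 2, ... thread-safely.
val nextBroadcastId = new AtomicLong(0)
println(nextBroadcastId.getAndIncrement())  // 0 -- the id behind "Broadcast(0)"
println(nextBroadcastId.getAndIncrement())  // 1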

TorrentBroadcastFactory's newBroadcast method creates a TorrentBroadcast object.

private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
  ...
  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = {
    new TorrentBroadcast[T](value_, id)
  }
  ...
}

The TorrentBroadcast constructor calls writeBlocks, which writes the broadcast value into the driver's BlockManager so that executors can later fetch it.

private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {
  ...
  private val numBlocks: Int = writeBlocks(obj)

  private def writeBlocks(value: T): Int = {
    import StorageLevel._
    val blockManager = SparkEnv.get.blockManager
    // Store a copy of the broadcast value on the driver, so that tasks run on the
    // driver do not create a duplicate copy of the value.
    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
    }
    // Serialize the object into chunks of bytes
    val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      // Save the chunks to the BlockManager
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
      }
    }
    blocks.length
  }
}

Reading a broadcast

When value is called on a broadcast variable, TorrentBroadcast's getValue method runs, which ultimately calls readBroadcastBlock.

readBroadcastBlock proceeds in three steps: it first checks the JVM-local broadcast cache, then the local BlockManager, and only if both miss does it fetch the serialized pieces from remote nodes, reassemble the value, and cache it locally.

private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  extends Broadcast[T](id) with Logging with Serializable {

  @transient private lazy val _value: T = readBroadcastBlock()

  override protected def getValue() = {
    _value
  }

  private def readBroadcastBlock(): T = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
      // Return the value from the cache if it is already there
      Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
        setConf(SparkEnv.get.conf)
        val blockManager = SparkEnv.get.blockManager
        // Try the local BlockManager next
        blockManager.getLocalValues(broadcastId) match {
          case Some(blockResult) =>
            if (blockResult.data.hasNext) {
              val x = blockResult.data.next().asInstanceOf[T]
              releaseLock(broadcastId)
              if (x != null) {
                // Cache the value locally
                broadcastCache.put(broadcastId, x)
              }
              x
            } else {
              throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
            }
          case None =>
            logInfo("Started reading broadcast variable " + id)
            val startTimeMs = System.currentTimeMillis()
            // Fetch the data from remote nodes
            val blocks = readBlocks()
            logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
            try {
              val obj = TorrentBroadcast.unBlockifyObject[T](
                blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
              val storageLevel = StorageLevel.MEMORY_AND_DISK
              if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
                throw new SparkException(s"Failed to store $broadcastId in BlockManager")
              }
              if (obj != null) {
                // Cache the value locally
                broadcastCache.put(broadcastId, obj)
              }
              obj
            } finally {
              blocks.foreach(_.dispose())
            }
        }
      }
    }
  }
}
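
Since every executor that reads the value ends up caching a deserialized copy, releasing a broadcast that is no longer needed matters for memory. A minimal usage sketch of the standard Broadcast lifecycle API, assuming an active SparkContext sc:

// Sketch: broadcast lifecycle from the caller's side.
val rdd = sc.parallelize(1 to 100)
val bc = sc.broadcast(Array(1, 2, 3))
rdd.map(x => x + bc.value.length).count()
bc.unpersist()  // drop cached copies on executors; re-fetched if used again
bc.destroy()    // remove all state permanently; the variable is unusable afterwards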

readBlocks fetches the data pieces in random order, and each piece is taken from whichever node already holds it. This avoids the hotspot that would arise if many executors pulled all of the data from the driver at the same time.

private def readBlocks(): Array[BlockData] = {
  val blocks = new Array[BlockData](numBlocks)
  val bm = SparkEnv.get.blockManager
  // Fetch the pieces in random order
  for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
    val pieceId = BroadcastBlockId(id, "piece" + pid)
    logDebug(s"Reading piece $pieceId of $broadcastId")
    bm.getLocalBytes(pieceId) match {
      case Some(block) =>
        blocks(pid) = block
        releaseLock(pieceId)
      case None =>
        bm.getRemoteBytes(pieceId) match {
          case Some(b) =>
            if (checksumEnabled) {
              val sum = calcChecksum(b.chunks(0))
              if (sum != checksums(pid)) {
                throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                  s" $sum != ${checksums(pid)}")
              }
            }
            // Store the fetched piece in the local BlockManager
            if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
              throw new SparkException(
                s"Failed to store $pieceId of $broadcastId in local BlockManager")
            }
            blocks(pid) = new ByteBufferBlockData(b, true)
          case None =>
            throw new SparkException(s"Failed to get $pieceId of $broadcastId")
        }
    }
  }
  blocks
}
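
Because each remotely fetched piece is stored back into the local BlockManager with tellMaster = true, an executor that has fetched a piece becomes a source for other executors, which is what makes the mechanism torrent-like. The random piece order then spreads those fetches across sources. A small sketch of the shuffling alone:

import scala.util.Random

// Sketch: two executors visit the same 5 pieces in independent random orders,
// so they rarely compete for the same piece from the same source at once.
val numBlocks = 5
println(Random.shuffle(Seq.range(0, numBlocks)))  // e.g. List(3, 0, 4, 1, 2)
println(Random.shuffle(Seq.range(0, numBlocks)))  // e.g. List(1, 4, 2, 0, 3)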
