sparkmagic
sparkmagic copied to clipboard
error: not found: type UserDefinedAggregateFunction
I want to define a custom aggregation function in Jupyter; the code is as follows:
import org.apache.spark.sql.functions._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer,UserDefinedAggregateFunction}
import org.apache.spark.ml.linalg.{Vector, Vectors, SQLDataTypes}
import scala.collection.mutable.WrappedArray
import java.text.SimpleDateFormat
import java.util.Calendar
/** UDAF that element-wise sums an ml `Vector` column into a dense vector of
  * fixed dimension `n`.
  *
  * NOTE(review): assumes every input vector has dimension <= n; an index at or
  * beyond n would throw — TODO confirm callers guarantee this.
  */
class VectorSum (n: Int) extends UserDefinedAggregateFunction {
// Input: a single ml Vector column (field name "v" is arbitrary).
def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
// Intermediate state: an array of n running per-element sums.
def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
// Result type: an ml Vector (dense, see evaluate).
def dataType = SQLDataTypes.VectorType
// Same inputs always yield the same sum.
def deterministic = true
// Start from an all-zero accumulator of length n.
def initialize(buffer: MutableAggregationBuffer) = {
buffer.update(0, Array.fill(n)(0.0))
}
// Fold one input row into the buffer; null vectors are skipped.
def update(buffer: MutableAggregationBuffer, input: Row) = {
if (!input.isNullAt(0)) {
val buff = buffer.getAs[WrappedArray[Double]](0)
// toSparse so the loop below only visits the active (non-zero) indices.
val v = input.getAs[Vector](0).toSparse
for (i <- v.indices) {
buff(i) += v(i)
}
// Write the mutated array back so Spark persists the new state.
buffer.update(0, buff)
}
}
// Combine two partial aggregates by element-wise addition into buffer1.
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val buff1 = buffer1.getAs[WrappedArray[Double]](0)
val buff2 = buffer2.getAs[WrappedArray[Double]](0)
for ((x, i) <- buff2.zipWithIndex) {
buff1(i) += x
}
buffer1.update(0, buff1)
}
// Final result: the accumulated sums as a dense ml Vector.
def evaluate(buffer: Row) = Vectors.dense(
buffer.getAs[Seq[Double]](0).toArray)
}
/** Small date-string helpers built on java.text / java.util. */
object Utils {
/** Returns `time` shifted by `delta` days, parsed and re-rendered with `format`.
  * Defaults to one day back in "yyyyMMdd" form.
  */
def getTimeString(time: String, delta: Int = -1, format: String = "yyyyMMdd"): String = {
val formatter = new SimpleDateFormat(format)
val calendar = Calendar.getInstance()
calendar.setTime(formatter.parse(time))
calendar.add(Calendar.DATE, delta)
formatter.format(calendar.getTime)
}
/** Today's date as "yyyyMMdd". */
def getTodayString(): String =
new SimpleDateFormat("yyyyMMdd").format(Calendar.getInstance().getTime)
/** Yesterday's date as "yyyyMMdd" (today shifted by the default delta of -1). */
def getYesterdayString(): String =
getTimeString(getTodayString())
}
but it fails with the following errors:
<console>:14: error: not found: type UserDefinedAggregateFunction
class VectorSum (n: Int) extends UserDefinedAggregateFunction {
^
<console>:15: error: not found: type StructType
def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
^
<console>:15: error: not found: value SQLDataTypes
def inputSchema = new StructType().add("v", SQLDataTypes.VectorType)
^
<console>:16: error: not found: type StructType
def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
^
<console>:16: error: not found: value ArrayType
def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
^
<console>:16: error: not found: value DoubleType
def bufferSchema = new StructType().add("buff", ArrayType(DoubleType))
^
<console>:17: error: not found: value SQLDataTypes
def dataType = SQLDataTypes.VectorType
^
<console>:20: error: not found: type MutableAggregationBuffer
def initialize(buffer: MutableAggregationBuffer) = {
^
<console>:24: error: not found: type Row
def update(buffer: MutableAggregationBuffer, input: Row) = {
^
<console>:24: error: not found: type MutableAggregationBuffer
def update(buffer: MutableAggregationBuffer, input: Row) = {
^
<console>:26: error: not found: type WrappedArray
val buff = buffer.getAs[WrappedArray[Double]](0)
^
<console>:35: error: not found: type MutableAggregationBuffer
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
^
<console>:36: error: not found: type WrappedArray
val buff1 = buffer1.getAs[WrappedArray[Double]](0)
^
<console>:35: error: not found: type Row
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
^
<console>:37: error: not found: type WrappedArray
val buff2 = buffer2.getAs[WrappedArray[Double]](0)
^
<console>:44: error: not found: value Vectors
def evaluate(buffer: Row) = Vectors.dense(
^
<console>:44: error: not found: type Row
def evaluate(buffer: Row) = Vectors.dense(
^
Note that the import statement itself — import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} — executes without reporting any error; only the class definition that uses those types fails.
I met the same error like yours. Do you find the reason?
I think I found a way to work around this bug. You need to use the fully-qualified name of the type: for example, replace UserDefinedAggregateFunction with org.apache.spark.sql.expressions.UserDefinedAggregateFunction.
FYI, I found this method at this mail list.