import System.{exit, nanoTime}
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.{Column, SparkSession, DataFrame}
import org.apache.spark.sql.functions._
import spark.implicits._
object Main extends Serializable {
/** Minimum support threshold, as a fraction of all baskets.
  * An itemset is kept only when its basket count is strictly greater than `s * basketCount`.
  */
val s: Double = 0.03
/** Builds a small in-memory basket dataset for local experimentation.
  *
  * Each string is one basket of space-separated integer item ids. The string is trimmed
  * and split on runs of whitespace: Spark's `split` keeps trailing empty tokens, and an
  * empty token would otherwise be cast to a `null` item in the `baskets` array.
  *
  * Relies on `spark.implicits._` (imported at the top of the file) for `Seq#toDF`.
  *
  * @return a DataFrame with the raw `baskets_str` string column and the parsed
  *         `baskets` column of type `array<int>`
  */
def loadFakeData() : DataFrame = {
  val rawBaskets = Seq(
    "1 ",
    "1 2 ",
    "1 2",
    "1 2 3 ",
    "1 2 ")
  rawBaskets
    .toDF("baskets_str")
    .withColumn("baskets", split(trim(col("baskets_str")), "\\s+").cast("array<int>"))
}
/** Generates candidate itemsets one element larger than an input itemset.
  *
  * For every element of `a1` missing from `a2`, yields `a2` plus that element, and
  * symmetrically for every element of `a2` missing from `a1`. Candidates are collected
  * in a `Set`, so duplicates are dropped. Each candidate is returned as an ascending array
  * so that equal itemsets are also equal as Spark array values.
  *
  * @param a1 first itemset
  * @param a2 second itemset
  * @return distinct candidate itemsets, each sorted ascending; empty when both inputs
  *         contain exactly the same elements
  */
def combo(a1: WrappedArray[Int], a2: WrappedArray[Int]): Array[Array[Int]] = {
  val left = a1.toSet
  val right = a2.toSet
  val candidates = left.diff(right).map(right + _) ++ right.diff(left).map(left + _)
  // Sorting gives a canonical order, so {2,1} and {1,2} become the same array downstream.
  candidates.map(_.toArray.sorted).toArray
}
/** Spark UDF wrapper around `combo`: takes two integer-array itemset columns and
  * returns an array of candidate itemsets (an array of integer arrays).
  */
val comboUDF = udf((first: WrappedArray[Int], second: WrappedArray[Int]) => combo(first, second))
/** Builds the next level of candidate itemsets from the current one.
  *
  * Every itemset is paired with every itemset (a cross join), and `comboUDF` expands each
  * pair into candidates one element larger. The candidates are exploded into one row each
  * and de-duplicated. De-duplication is required because both orderings of a pair, and
  * different pairs, can produce the same candidate, and duplicate candidates would
  * inflate support counts later.
  *
  * @param df itemsets, with an array column `itemsets` (other columns are ignored)
  * @return distinct candidate itemsets in a single array column `itemsets`
  */
def getCombinations(df: DataFrame): DataFrame = {
  // Keep only the itemsets so that no other column (e.g. `count`) is duplicated by the join.
  val current = df.select(col("itemsets"))
  current
    .crossJoin(current.withColumnRenamed("itemsets", "itemsets_2"))
    .select(explode(comboUDF(col("itemsets"), col("itemsets_2"))).as("itemsets"))
    .distinct()
}
/** Counts, for every candidate itemset, how many baskets contain it.
  *
  * Each basket is paired with each candidate. A pair is kept when every element of the
  * candidate occurs in the basket, i.e. when the (de-duplicated) intersection is as large
  * as the candidate. This assumes candidate itemsets contain no repeated elements, which
  * holds for `combo` output — NOTE(review): confirm for the single-element seed itemsets.
  * Candidates are de-duplicated first, so each basket is counted at most once per itemset.
  *
  * @param data         baskets, with an array column `baskets`
  * @param combinations candidate itemsets, with an array column `itemsets`
  * @return one row per candidate found in at least one basket, with columns
  *         `itemsets` and `count`
  */
def countCombinations(data : DataFrame, combinations: DataFrame) : DataFrame = {
  val candidates = combinations.select(col("itemsets")).distinct()
  data
    .crossJoin(candidates)
    .where(size(array_intersect(col("baskets"), col("itemsets"))) === size(col("itemsets")))
    .groupBy("itemsets")
    .count()
}
def freq() {
val spark = SparkSession.builder.appName("FreqItemsets")
// data is a dataframe where each row contains an array of integer values
var data = loadFakeData()
val basket_count = data.count
// Itemset is a dataframe containing all possible sets of 1 element
var itemset : DataFrame = data
.withColumnRenamed("col", "itemsets")
.withColumn("itemsets", array('itemsets))
var itemset_count : DataFrame = countCombinations(data, itemset).filter('count > s*basket_count)
var itemset_counts = List(itemset_count)
// We iterate creating each time itemsets of length k+1 from itemsets of length k
// pruning those that do not have enough support
var stop = (itemset_count.count == 0)
while(!stop) {
itemset = getCombinations("itemsets"))
itemset_count = countCombinations(data, itemset).filter('count > s*basket_count)
stop = (itemset_count.count == 0)
if (!stop) {
itemset_counts = itemset_counts :+ itemset_count
itemset = getCombinations("itemsets")).cache
itemset_counts = itemset_count :: itemset_counts
