update how_to_make_hf_dataset.md
ShawnXuan committed Jun 15, 2021
1 parent d1275e6 commit 395c28f
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md
@@ -1,6 +1,6 @@
# How to Make High Frequency Dataset for OneFlow-WDL

[**how_to_make_ofrecord_for_wdl**](https://github.com/Oneflow-Inc/OneFlow-Benchmark/blob/master/ClickThroughRate/WideDeepLearning/how_to_make_ofrecord_for_wdl.md) explained how to use Spark to build the ofrecord dataset for OneFlow-WDL. That dataset is not suitable for the mixed GPU & CPU embedding practice, mainly because it is not sorted by frequency, so a new dataset has to be built. Following the same approach as the previous article, this one explains how to build a dataset sorted by frequency.
[**how_to_make_ofrecord_for_wdl**](https://github.com/Oneflow-Inc/OneFlow-Benchmark/blob/master/ClickThroughRate/WideDeepLearning/how_to_make_ofrecord_for_wdl.md) explained how to use Spark to build the ofrecord dataset for OneFlow-WDL. The mixed GPU & CPU embedding practice requires the features to be sorted by frequency in descending order. Following the same approach as the previous article, this one explains how to build a dataset sorted by frequency.
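
To illustrate what "sorted by frequency" means here, the sketch below ranks the values of one categorical column so that more frequent values get smaller ids. It is a minimal toy example, not part of the original pipeline; the DataFrame `df` and column `C1` are assumed for illustration.

```scala
// Minimal sketch (assumed names `df` and `C1`): rank the values of one categorical
// column by how often they occur, so the most frequent value gets id 0.
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val counted = df.groupBy("C1").agg(count(lit(1)).as("cnt"))        // occurrences of each value
val idByFreq = counted.withColumn("C1_id",
  row_number().over(Window.orderBy(col("cnt").desc)) - 1)          // 0 = most frequent value
val remapped = df.join(idByFreq.select("C1", "C1_id"), Seq("C1"))  // replace raw values with ranked ids
```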

## Dataset and Preprocessing

@@ -53,14 +53,12 @@ Files.createDirectories(Paths.get(tmp_dir))
```scala
// load input file
var input = spark.read.options(Map("delimiter"->"\t")).csv("file:///DATA/disk1/xuan/train.shuf.bak")
// var input = spark.read.options(Map("delimiter"->"\t")).csv("file:///DATA/disk1/xuan/train.shuf.txt")

// rename columns [label, I1,...,I13, C1,...,C26]
val NUM_INTEGER_COLUMNS = 13
val NUM_CATEGORICAL_COLUMNS = 26

// val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>s"I$id"}
val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>$"I$id"} // note
val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>$"I$id"}
val categorical_cols = (1 to NUM_CATEGORICAL_COLUMNS).map{id=>s"C$id"}
val feature_cols = integer_cols.map{c=>c.toString} ++ categorical_cols
val all_cols = (Seq(s"labels") ++ feature_cols)
@@ -133,7 +131,6 @@ for(column_name <- integer_cols) {
} else {
scaledDf = indexedDf.withColumn(col_name, col(col_index) + lit(1)) // trick: reuse col_name
.select("id", col_name)
//.withColumn(col_name, col(col_index).cast(IntegerType))
}
val col_dir = tmp_dir ++ "/" ++ col_name
scaledDf = scaledDf.withColumn(col_name, getItem(column_name, lit(0)))
@@ -233,8 +230,6 @@ for(cross_pair <- cross_pairs) {
val df_col = spark.read.parquet(tmp_dir ++ "/" ++ cross_pair)
df = df.join(df_col, Seq("id"))
}
// df.select("C1_C2", "C3_C4").createOrReplaceTempView("f")
// df.select(cross_pairs.map{id=>col(id)}:_*).createOrReplaceTempView("f")
df.select(cross_pairs map col: _*).createOrReplaceTempView("f")

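// Count how many times each crossed-feature value occurs across all cross_pair columns,
// keep only the values seen at least 6 times, and order them by count in descending order.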
val orderedValues = spark.sql("select cid, count(*) as cnt from (select explode( array(" + cross_pairs.mkString(",") + ") ) as cid from f) group by cid ").filter("cnt>=6").orderBy($"cnt".desc)
