今回は pyspark.sql.DataFrame クラスの主要なメソッドを備忘録用にまとめてみました。 環境は macOS 10.13.3, Apache Spark 2.3.0 です。
PySparkの起動
Spark は Download Apache Spark から DL できる。
$ export PYSPARK_DRIVER_PYTHON=ipython
$ pyspark
Python 2.7.11 |Anaconda custom (x86_64)| (default, Jun 15 2016, 16:09:16)
Type "copyright", "credits" or "license" for more information.
IPython 4.2.0 -- An enhanced Interactive Python.
? -> Introduction and overview of IPython's features.
%quickref -> Quick reference.
help -> Python's own help system.
object? -> Details about 'object', use 'object??' for extra details.
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/__ / .__/\_,_/_/ /_/\_\ version 2.3.0
/_/
Using Python version 2.7.11 (default, Jun 15 2016 16:09:16)
SparkSession available as 'spark'.
In [1]: sc
Out[1]: <SparkContext master=local[*] appName=PySparkShell>
In [2]: df = spark.createDataFrame(
...: [(1, 144.5, 5.9, 33, 'M'), (2, 167.2, 5.4, 45, 'M'), (3, 124.1, 5.2, 23, 'F'),
...: (4, 144.5, 5.9, 33, 'M'), (5, 133.2, 5.7, 54, 'F'), (3, 124.1, 5.2, 23, 'F'),
...: (6, 149.3, None, 54, 'M'),],
...: ['id', 'weight', 'height', 'age', 'gender'])
1. データ構造の確認
DataFrame (Dataset Untyped API) は RDD (Resilient Distributed Dataset) より抽象化された構造を持っている。
columns は全ての列名を list で返す。
In [3]: df.columns
Out[3]: ['id', 'weight', 'height', 'age', 'gender']
withColumnRenamed(existing, new) は既に存在する列名をリネームし, 新しい DataFrame を返す。
In [4]: df.withColumnRenamed('id', 'index').show()
Out[4]:
+-----+------+------+---+------+
|index|weight|height|age|gender|
+-----+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 2| 167.2| 5.4| 45| M|
| 3| 124.1| 5.2| 23| F|
| 4| 144.5| 5.9| 33| M|
| 5| 133.2| 5.7| 54| F|
| 3| 124.1| 5.2| 23| F|
| 6| 149.3| null| 54| M|
+-----+------+------+---+------+
dtypes は全ての列名とその型の list を返す。
In [5]: df.dtypes
Out[5]:
[('id', 'bigint'),
('weight', 'double'),
('height', 'double'),
('age', 'bigint'),
('gender', 'string')]
printSchema() は DataFrame の schema を木構造で出力する。
In [6]: df.printSchema()
root
|-- id: long (nullable = true)
|-- weight: double (nullable = true)
|-- height: double (nullable = true)
|-- age: long (nullable = true)
|-- gender: string (nullable = true)
head(n=None) は先頭 n 行を Row の list で返す。
In [7]: df.head()
Out[7]: Row(id=1, weight=144.5, height=5.9, age=33, gender=u'M')
count() はレコード数を返す。
In [8]: df.count()
Out[8]: 7
collect() は全てのレコードを Row の list で返す。
In [9]: df.collect()
Out[9]:
[Row(id=1, weight=144.5, height=5.9, age=33, gender=u'M'),
Row(id=2, weight=167.2, height=5.4, age=45, gender=u'M'),
Row(id=3, weight=124.1, height=5.2, age=23, gender=u'F'),
Row(id=4, weight=144.5, height=5.9, age=33, gender=u'M'),
Row(id=5, weight=133.2, height=5.7, age=54, gender=u'F'),
Row(id=3, weight=124.1, height=5.2, age=23, gender=u'F'),
Row(id=6, weight=149.3, height=None, age=54, gender=u'M')]
show(n=20, truncate=True, vertical=False) は DataFrame の先頭 n 行を出力する。
In [10]: df.show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 2| 167.2| 5.4| 45| M|
| 3| 124.1| 5.2| 23| F|
| 4| 144.5| 5.9| 33| M|
| 5| 133.2| 5.7| 54| F|
| 3| 124.1| 5.2| 23| F|
| 6| 149.3| null| 54| M|
+---+------+------+---+------+
crosstab(col1, col2) は指定した2変数の分割表を返す。
In [11]: df.crosstab("age", "gender").show()
+----------+---+---+
|age_gender| F| M|
+----------+---+---+
| 23| 2| 0|
| 45| 0| 1|
| 54| 1| 1|
| 33| 0| 2|
+----------+---+---+
2. 射影・抽出
select(*cols) は列名または式の射影を行い, 結果を DataFrame で返す。
In [12]: df.select('id', 'gender').show()
+---+------+
| id|gender|
+---+------+
| 1| M|
| 2| M|
| 3| F|
| 4| M|
| 5| F|
| 3| F|
| 6| M|
+---+------+
filter(condition) は指定した condition で行をフィルタリングする。 where() は filter() のエイリアス。
In [13]: df.filter(df.weight > 145.0).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 2| 167.2| 5.4| 45| M|
| 6| 149.3| null| 54| M|
+---+------+------+---+------+
In [14]: df.where(df.gender == 'M').show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 2| 167.2| 5.4| 45| M|
| 4| 144.5| 5.9| 33| M|
| 6| 149.3| null| 54| M|
+---+------+------+---+------+
複数の condition を指定することもでき, 論理積は &, 論理和は |, 論理否定は ~ 演算子を使う。
In [15]: df.filter((df.weight > 145.0) | (df.gender == 'M')).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 2| 167.2| 5.4| 45| M|
| 4| 144.5| 5.9| 33| M|
| 6| 149.3| null| 54| M|
+---+------+------+---+------+
In [16]: from pyspark.sql.functions import col
In [17]: df.filter(~(col("weight") > 145.0) & (col("gender") == 'M')).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 4| 144.5| 5.9| 33| M|
+---+------+------+---+------+
drop(*cols) は指定した列を削除した新しい DataFrame を返す。
In [18]: df.drop("age").show()
+---+------+------+------+
| id|weight|height|gender|
+---+------+------+------+
| 1| 144.5| 5.9| M|
| 2| 167.2| 5.4| M|
| 3| 124.1| 5.2| F|
| 4| 144.5| 5.9| M|
| 5| 133.2| 5.7| F|
| 3| 124.1| 5.2| F|
| 6| 149.3| null| M|
+---+------+------+------+
3. 要約統計量
summary(*statistics) は要約統計量 (count, mean, stddev, min, max, percentiles) を返す。
In [19]: df.select('weight', 'height', 'age').summary().show()
+-------+------------------+-------------------+------------------+
|summary| weight| height| age|
+-------+------------------+-------------------+------------------+
| count| 7| 6| 7|
| mean| 140.9857142857143| 5.550000000000001|37.857142857142854|
| stddev|15.339972682660223|0.32710854467592254| 13.29697423512296|
| min| 124.1| 5.2| 23|
| 25%| 124.1| 5.2| 23|
| 50%| 144.5| 5.4| 33|
| 75%| 149.3| 5.9| 54|
| max| 167.2| 5.9| 54|
+-------+------------------+-------------------+------------------+
describe(*cols) は指定した列の要約統計量を返す。
In [20]: df.describe(['weight', 'height', 'age']).show()
+-------+------------------+-------------------+------------------+
|summary| weight| height| age|
+-------+------------------+-------------------+------------------+
| count| 7| 6| 7|
| mean| 140.9857142857143| 5.550000000000001|37.857142857142854|
| stddev|15.339972682660223|0.32710854467592254| 13.29697423512296|
| min| 124.1| 5.2| 23|
| max| 167.2| 5.9| 54|
+-------+------------------+-------------------+------------------+
corr(col1, col2, method=None) は2変数の相関を double で返す。 Spark 2.3 時点ではピアソンの相関係数のみサポートしている。
In [21]: df.corr("weight", "height")
Out[21]: -0.1895423855498912
cov(col1, col2) は指定した2変数の標本共分散を double で返す。
In [22]: df.cov("weight", "height")
Out[22]: -6.160714285714281
4. 結合
join(other, on=None, how=None) は2つの DataFrame を結合する。
on には結合条件を指定し, how には結合方法 (inner, cross, outer, full, full_outer, left, left_outer, right, right_outer, left_semi, and left_ant) を指定する。
In [23]: df2 = spark.createDataFrame([(1, 'Tokyo',), (2, 'Kyoto',), ], ['id', 'from'])
In [24]: df2.join(df, df2.id == df.id, 'left').select(df.id, df.gender, df2['from']).show()
+---+------+-----+
| id|gender| from|
+---+------+-----+
| 1| M|Tokyo|
| 2| M|Kyoto|
+---+------+-----+
crossJoin(other) は2つの DataFrame の直積を返す。
In [25]: df.crossJoin(df2.select("from")).select("age", "gender", "from").show()
+---+------+-----+
|age|gender| from|
+---+------+-----+
| 33| M|Tokyo|
| 33| M|Kyoto|
| 45| M|Tokyo|
| 23| F|Tokyo|
| 45| M|Kyoto|
| 23| F|Kyoto|
| 33| M|Tokyo|
| 54| F|Tokyo|
| 33| M|Kyoto|
| 54| F|Kyoto|
| 23| F|Tokyo|
| 54| M|Tokyo|
| 23| F|Kyoto|
| 54| M|Kyoto|
+---+------+-----+
5. 統合 (連結)
union(other) は2つの DataFrame を統合し新しい DataFrame を返す。
In [26]: df3 = spark.createDataFrame([(7, 122.7, 5.1, 36, 'F')],
....: ['id', 'weight', 'height', 'age', 'gender'])
In [27]: df.union(df3).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 2| 167.2| 5.4| 45| M|
| 3| 124.1| 5.2| 23| F|
| 4| 144.5| 5.9| 33| M|
| 5| 133.2| 5.7| 54| F|
| 3| 124.1| 5.2| 23| F|
| 6| 149.3| null| 54| M|
| 7| 122.7| 5.1| 36| F|
+---+------+------+---+------+
6. グループ化・集約
groupBy(*cols) は指定した列でグループ化し, pyspark.sql.GroupedData を返す。 GroupedData に対して集約関数を適用する。
In [28]: df.select("height").groupBy().mean().collect()
Out[28]: [Row(avg(height)=5.550000000000001)]
agg(*exprs) は df.groupBy.agg() の省略形。
In [29]: df.agg({"age": "mean"}).collect()
Out[29]: [Row(avg(age)=37.857142857142854)]
7. 欠測値の確認・削除・補完
欠測値は pyspark.sql.Column クラスの isNull() で確認できる。
In [30]: from pyspark.sql.functions import col
In [31]: df.filter(col("height").isNull()).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 6| 149.3| null| 54| M|
+---+------+------+---+------+
In [32]: df.filter(col("height").isNotNull()).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 2| 167.2| 5.4| 45| M|
| 3| 124.1| 5.2| 23| F|
| 4| 144.5| 5.9| 33| M|
| 5| 133.2| 5.7| 54| F|
| 3| 124.1| 5.2| 23| F|
+---+------+------+---+------+
dropna(how=’any’, thresh=None, subset=None) は欠測値を含む行を削除した新しい DataFrame を返す。thresh には削除の閾値, subset に対象の列名を指定することができる。
In [33]: df.dropna().show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 2| 167.2| 5.4| 45| M|
| 3| 124.1| 5.2| 23| F|
| 4| 144.5| 5.9| 33| M|
| 5| 133.2| 5.7| 54| F|
| 3| 124.1| 5.2| 23| F|
+---+------+------+---+------+
fillna(value, subset=None) は指定した値で欠測値を補完する。subset に対象の列名を指定することができる。
In [34]: height_mean = round(df.select("height").groupBy().avg().head()[0], 1)
In [35]: df.fillna(height_mean).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| M|
| 2| 167.2| 5.4| 45| M|
| 3| 124.1| 5.2| 23| F|
| 4| 144.5| 5.9| 33| M|
| 5| 133.2| 5.7| 54| F|
| 3| 124.1| 5.2| 23| F|
| 6| 149.3| 5.6| 54| M|
+---+------+------+---+------+
8. 重複値の削除
distinct() は異なる値を持つ行で構成される新しい DataFrame を返す。 (重複行の確認)
In [36]: df.select(["id"]).distinct().count()
Out[35]: 6
dropDuplicates(subset=None) は重複値を削除した新しい DataFrame を返す。subset に対象の列名を指定することができる。
In [37]: df.dropDuplicates(['id']).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 6| 149.3| null| 54| M|
| 5| 133.2| 5.7| 54| F|
| 1| 144.5| 5.9| 33| M|
| 3| 124.1| 5.2| 23| F|
| 2| 167.2| 5.4| 45| M|
| 4| 144.5| 5.9| 33| M|
+---+------+------+---+------+
9. 置換・ソート
replace(to_replace, value=, subset=None) は to_replace に指定した値を value で置換する。
In [38]: df.replace(['M', 'F'], [None, None], 'gender').show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 1| 144.5| 5.9| 33| null|
| 2| 167.2| 5.4| 45| null|
| 3| 124.1| 5.2| 23| null|
| 4| 144.5| 5.9| 33| null|
| 5| 133.2| 5.7| 54| null|
| 3| 124.1| 5.2| 23| null|
| 6| 149.3| null| 54| null|
+---+------+------+---+------+
orderBy(*cols, **kwargs) は指定した列でソートし新しい DataFrame を返す。
In [39]: from pyspark.sql.functions import desc
In [40]: df.orderBy(desc("age"), "id").show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 5| 133.2| 5.7| 54| F|
| 6| 149.3| null| 54| M|
| 2| 167.2| 5.4| 45| M|
| 1| 144.5| 5.9| 33| M|
| 4| 144.5| 5.9| 33| M|
| 3| 124.1| 5.2| 23| F|
| 3| 124.1| 5.2| 23| F|
+---+------+------+---+------+
sort(*cols, **kwargs) も指定した列でソートし新しい DataFrame を返す。
In [41]: df.sort(df.age.desc()).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 5| 133.2| 5.7| 54| F|
| 6| 149.3| null| 54| M|
| 2| 167.2| 5.4| 45| M|
| 4| 144.5| 5.9| 33| M|
| 1| 144.5| 5.9| 33| M|
| 3| 124.1| 5.2| 23| F|
| 3| 124.1| 5.2| 23| F|
+---+------+------+---+------+
10. サンプリング
sample(withReplacement=None, fraction=None, seed=None) はサンプリングした DataFrame のサブセットを返す。
withReplacement には復元抽出の有無 (default は False), fraction にはサンプリングの割合 (0.0 ~ 1.0) を指定する。
In [41]: df.sample(False, fraction=0.3).show()
+---+------+------+---+------+
| id|weight|height|age|gender|
+---+------+------+---+------+
| 2| 167.2| 5.4| 45| M|
| 3| 124.1| 5.2| 23| F|
| 5| 133.2| 5.7| 54| F|
+---+------+------+---+------+
11. データ形式の変換
toJSON(use_unicode=True) は DataFrame の各行を JSON 文字列に変換した list を返す。
In [42]: df.toJSON().first()
Out[42]: u'{"id":1,"weight":144.5,"height":5.9,"age":33,"gender":"M"}'
toPandas() は DataFrame を pandas.DataFrame に変換する。ただし, RDD に対する collect() の操作になるので大きなサイズの DataFrame を扱う場合はドライバのメモリに注意する必要がある。
In [43]: df.toPandas()
Out[43]:
id weight height age gender
0 1 144.5 5.9 33 M
1 2 167.2 5.4 45 M
2 3 124.1 5.2 23 F
3 4 144.5 5.9 33 M
4 5 133.2 5.7 54 F
5 3 124.1 5.2 23 F
6 6 149.3 NaN 54 M
12. 日時の変換
以降, pyspark.sql.functions クラスのメソッドとなるが, よく行う処理なのでメモしておく。
unix_timestamp(timestamp=None, format=’yyyy-MM-dd HH:mm:ss’) は日時を表す文字列を UNIX タイムスタンプに変換する。 ISO 8601形式の日時の場合は format=”yyyy-MM-dd’T’HH:mm:ss” とする。
In [44]: df_ts = spark.createDataFrame([(1, "2018-04-29 16:09:16"), (2, "2018-04-29 18:01:32")], ("id", "dt_str"))
In [45]: from pyspark.sql.functions import udf, unix_timestamp, to_timestamp
In [46]: df_ts.withColumn("ts", unix_timestamp("dt_str")).show()
+---+-------------------+----------+
| id| dt_str| ts|
+---+-------------------+----------+
| 1|2018-04-29 16:09:16|1524985756|
| 2|2018-04-29 18:01:32|1524992492|
+---+-------------------+----------+
to_timestamp(col, format=None) は pyspark.sql.types.StringType または pyspark.sql.types.TimestampType を pyspark.sql.types.DateType に変換する。
In [47]: df_ts.withColumn("dt", to_timestamp("dt_str")).collect()
Out[47]:
[Row(id=1, dt_str=u'2018-04-29 16:09:16', dt=datetime.datetime(2018, 4, 29, 16, 9, 16)),
Row(id=2, dt_str=u'2018-04-29 18:01:32', dt=datetime.datetime(2018, 4, 29, 18, 1, 32))]
In [48]: hour = udf(lambda x: x.hour)
In [49]: df_ts.withColumn("dt", to_timestamp("dt_str")).withColumn('hour', hour(col('dt'))).show()
+---+-------------------+-------------------+----+
| id| dt_str| dt|hour|
+---+-------------------+-------------------+----+
| 1|2018-04-29 16:09:16|2018-04-29 16:09:16| 16|
| 2|2018-04-29 18:01:32|2018-04-29 18:01:32| 18|
+---+-------------------+-------------------+----+
おわりに
ML DataFrame-based API の pipeline を上手く活用するためにも DataFrame の操作には慣れておきたいですね。 参考書籍は『入門PySpark』です。
ちなみに PySpark を Jupyter Notebook で使いたい場合は, Jupyter Notebook がインストールされている状態で以下の環境変数を設定してから pyspark を起動します。
$ export PYSPARK_DRIVER_PYTHON=jupyter
$ export PYSPARK_DRIVER_PYTHON_OPTS='notebook'
$ pyspark
本題から逸れますが, Jupyter Notebook の拡張機能 jupyter_contrib_nbextensions が便利でした。インストール方法は以下です。
$ pip install jupyter_contrib_nbextensions
$ jupyter contrib nbextension install --user
Jupyter Notebook を起動すると Nbextensions タブが追加されており, 利用したい機能にチェックを入れます。個人的に文書を書くときに左側にナビゲーションを表示する ToC (Table of Contents) が気に入っています。
[1] Get Started with PySpark and Jupyter Notebook in 3 Minutes
[2] Pyspark: multiple conditions in when clause