General User-defined Functions
This documentation is for an unreleased version of Apache Flink. We recommend you use the latest stable version.

一般的なユーザ定義関数: #

ユーザ定義関数は、Python Table APIプログラムの表現力を大幅に拡張するため、重要な機能です。

注意: Python UDFの実行には、PyFlink がインストールされたPythonバージョン(3.8、3.9、3.10)が必要です。これはクライアント側とクラスタ側の両方で必要です。

Scalar 関数 #

Python Table APIプログラムでのPythonスカラー関数の使用をサポートします。Pythonスカラー関数を定義するために、pyflink.table.udfで基本クラスScalarFunctionを拡張し、評価メソッドを実装できます。 Pythonスカラー関数の挙動は、evalという評価メソッドで定義されます。 評価メソッドはeval(*args)のような可変変数をサポートできます。

以下の例は、独自のPythonハッシュコード関数を定義し、それをTableEnvironmentに登録し、クエリ内でそれを呼び出す方法を示しています。 登録する前にコンストラクタを介してスカラー関数を設定できることに注意してください:

from pyflink.table.expressions import call, col
from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
from pyflink.table.udf import ScalarFunction, udf

class HashCode(ScalarFunction):
  def __init__(self):
    self.factor = 12

  def eval(self, s):
    return hash(s) * self.factor

settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(settings)

hash_code = udf(HashCode(), result_type='BIGINT')

# use the Python function in Python Table API
my_table.select(col("string"), col("bigint"), hash_code(col("bigint")), call(hash_code, col("bigint")))

# use the Python function in SQL API
table_env.create_temporary_function("hash_code", udf(HashCode(), result_type='BIGINT'))
table_env.sql_query("SELECT string, bigint, hash_code(bigint) FROM MyTable")

Python Table APIプログラムでJava/Scalaスカラー関数を使うこともサポートします。

'''
Java code:

// The Java class must have a public no-argument constructor and can be founded in current Java classloader.
public class HashCode extends ScalarFunction {
  private int factor = 12;

  public int eval(String s) {
      return s.hashCode() * factor;
  }
}
'''
from pyflink.table.expressions import call, col
from pyflink.table import TableEnvironment, EnvironmentSettings

settings = EnvironmentSettings.in_batch_mode()
table_env = TableEnvironment.create(settings)

# register the Java function
table_env.create_java_temporary_function("hash_code", "my.java.function.HashCode")

# use the Java function in Python Table API
my_table.select(call('hash_code', col("string")))

# use the Java function in SQL API
table_env.sql_query("SELECT string, bigint, hash_code(string) FROM MyTable")

基本クラスScalarFunctionを拡張する以外にも、Pythonスカラー関数を定義する方法はたくさんあります。 以下の例は、入力パラメータとしてbigintの2つのカラムを受け取り、それらの合計を結果として返すPythonスカラー関数を定義する様々な方法を示しています。

# option 1: extending the base class `ScalarFunction`
class Add(ScalarFunction):
  def eval(self, i, j):
    return i + j

add = udf(Add(), result_type=DataTypes.BIGINT())

# option 2: Python function
@udf(result_type='BIGINT')
def add(i, j):
  return i + j

# option 3: lambda function
add = udf(lambda i, j: i + j, result_type='BIGINT')

# option 4: callable function
class CallableAdd(object):
  def __call__(self, i, j):
    return i + j

add = udf(CallableAdd(), result_type='BIGINT')

# option 5: partial function
def partial_add(i, j, k):
  return i + j + k

add = udf(functools.partial(partial_add, k=1), result_type='BIGINT')

# register the Python function
table_env.create_temporary_function("add", add)
# use the function in Python Table API
my_table.select(call('add', col('a'), col('b')))

# You can also use the Python function in Python Table API directly
my_table.select(add(col('a'), col('b')))

Table関数 #

Pythonのユーザ定義スカラー関数と同様に、ユーザ定義のtable関数は0個以上のスカラー値を入力パラメータとして受け取ります。ただし、スカラー関数とは対照的に、1つの値ではなく、任意の数の行を出力として返すことができます。Python UDFの返り値の型は、Iterable、Iterator、generatorの型になります。

以下の例あh、独自のPythonマルチ発行関数を定義し、それをTableEnvironmentに登録し、クエリ内でそれを呼び出す方法を示しています。

from pyflink.table.expressions import col
from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
from pyflink.table.udf import TableFunction, udtf

class Split(TableFunction):
    def eval(self, string):
        for s in string.split(" "):
            yield s, len(s)

env_settings = EnvironmentSettings.in_streaming_mode()
table_env = TableEnvironment.create(env_settings)
my_table = ...  # type: Table, table schema: [a: String]

# register the Python Table Function
split = udtf(Split(), result_types=['STRING', 'INT'])

# use the Python Table Function in Python Table API
my_table.join_lateral(split(col("a")).alias("word", "length"))
my_table.left_outer_join_lateral(split(col("a")).alias("word", "length"))

# use the Python Table function in SQL API
table_env.create_temporary_function("split", udtf(Split(), result_types=['STRING', 'INT']))
table_env.sql_query("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
table_env.sql_query("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")

Python Table APIプログラムでJava/Scalaのtable関数を使うこともサポートします。

'''
Java code:

// The generic type "Tuple2<String, Integer>" determines the schema of the returned table as (String, Integer).
// The java class must have a public no-argument constructor and can be founded in current java classloader.
public class Split extends TableFunction<Tuple2<String, Integer>> {
    private String separator = " ";
    
    public void eval(String str) {
        for (String s : str.split(separator)) {
            // use collect(...) to emit a row
            collect(new Tuple2<String, Integer>(s, s.length()));
        }
    }
}
'''
from pyflink.table.expressions import call, col
from pyflink.table import TableEnvironment, EnvironmentSettings

env_settings = EnvironmentSettings.in_streaming_mode()
table_env = TableEnvironment.create(env_settings)
my_table = ...  # type: Table, table schema: [a: String]

# Register the java function.
table_env.create_java_temporary_function("split", "my.java.function.Split")

# Use the table function in the Python Table API. "alias" specifies the field names of the table.
my_table.join_lateral(call('split', col('a')).alias("word", "length")).select(col('a'), col('word'), col('length'))
my_table.left_outer_join_lateral(call('split', col('a')).alias("word", "length")).select(col('a'), col('word'), col('length'))

# Register the python function.

# Use the table function in SQL with LATERAL and TABLE keywords.
# CROSS JOIN a table function (equivalent to "join" in Table API).
table_env.sql_query("SELECT a, word, length FROM MyTable, LATERAL TABLE(split(a)) as T(word, length)")
# LEFT JOIN a table function (equivalent to "left_outer_join" in Table API).
table_env.sql_query("SELECT a, word, length FROM MyTable LEFT JOIN LATERAL TABLE(split(a)) as T(word, length) ON TRUE")

Pythonスカラー関数のように、上の5つの方法を使ってPython TableFunctionを定義することができます。

注意 唯一の違いは、Python Table関数の返り値の型はiterable、iterator、generatorである必要があるということです。

# option 1: generator function
@udtf(result_types='BIGINT')
def generator_func(x):
      yield 1
      yield 2

# option 2: return iterator
@udtf(result_types='BIGINT')
def iterator_func(x):
      return range(5)

# option 3: return iterable
@udtf(result_types='BIGINT')
def iterable_func(x):
      result = [1, 2, 3]
      return result

集約関数 #

ユーザ定義の集計関数(UDAGG)は複数の行のスカラー値を新しいスカラー値にマップします。

注意: 現在、一般的なユーザ定義集計関数は、ストリーミングモードのGroupBy集計とGroup Window集計でのみサポートされています。バッチモードの場合は現在さぽーとされていないため、ベクトル化集計関数を使うことをお勧めします。

集計関数の挙動は、accumulatorの概念を中心にしています。_accumulator_は、最終的な集計結果が計算されるまで集計された値を保存する中間データ構造です。

集計する必要がある行のセットごとに、ランタイムはcreate_accumulator()を呼び出して空のaccumulatorを作成します。その後、集計関数のaccumulate(...)メソッドが入力行ごとに呼び出され、accumulatorが更新されます。現在、各行が処理された後で、集計関数のget_value(...)メソッドが呼び出され、集計された結果を計算します。

以下の例は、集計処理を示しています:

UDAGG mechanism

上の例では、飲料に関するデータが含まれるテーブルを想定しています。tableは3つの列(idnameprice)と5行から構成されています。table内の全ての飲料の最高価格を見つけたいと考えています。つまり、max()集計を実行します。

集計関数を定義するには、pyflink.tableの基本クラスAggregateFunctionを拡張し、accumulate(...)という名前の評価関数を実装する必要があります。 集計関数の結果の型とaccumulatorの型は、次の2つの方法のいずれかで指定できます:

  • get_result_type()get_accumulator_type()という名前のメソッドを実装します。
  • 関数のインスタンスを pyflink.table.udf内のデコレータudafでラップし、パラメータresult_typeaccumulator_typeを指定します。

以下の例は、独自の集計関数を定義し、クエリで呼び出す方法を示しています。

from pyflink.common import Row
from pyflink.table import AggregateFunction, DataTypes, TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import call
from pyflink.table.udf import udaf
from pyflink.table.expressions import col, lit
from pyflink.table.window import Tumble


class WeightedAvg(AggregateFunction):

    def create_accumulator(self):
        # Row(sum, count)
        return Row(0, 0)

    def get_value(self, accumulator):
        if accumulator[1] == 0:
            return None
        else:
            return accumulator[0] / accumulator[1]

    def accumulate(self, accumulator, value, weight):
        accumulator[0] += value * weight
        accumulator[1] += weight
    
    def retract(self, accumulator, value, weight):
        accumulator[0] -= value * weight
        accumulator[1] -= weight
        
    def get_result_type(self):
        return 'BIGINT'
        
    def get_accumulator_type(self):
        return 'ROW<f0 BIGINT, f1 BIGINT>'


env_settings = EnvironmentSettings.in_streaming_mode()
table_env = TableEnvironment.create(env_settings)
# the result type and accumulator type can also be specified in the udaf decorator:
# weighted_avg = udaf(WeightedAvg(), result_type=DataTypes.BIGINT(), accumulator_type=...)
weighted_avg = udaf(WeightedAvg())
t = table_env.from_elements([(1, 2, "Lee"),
                             (3, 4, "Jay"),
                             (5, 6, "Jay"),
                             (7, 8, "Lee")]).alias("value", "count", "name")

# call function "inline" without registration in Table API
result = t.group_by(col("name")).select(weighted_avg(col("value"), col("count")).alias("avg")).execute()
result.print()

# register function
table_env.create_temporary_function("weighted_avg", WeightedAvg())

# call registered function in Table API
result = t.group_by(col("name")).select(call("weighted_avg", col("value"), col("count")).alias("avg")).execute()
result.print()

# register table
table_env.create_temporary_view("source", t)

# call registered function in SQL
result = table_env.sql_query(
    "SELECT weighted_avg(`value`, `count`) AS avg FROM source GROUP BY name").execute()
result.print()

# use the general Python aggregate function in GroupBy Window Aggregation
tumble_window = Tumble.over(lit(1).hours) \
            .on(col("rowtime")) \
            .alias("w")

result = t.window(tumble_window) \
        .group_by(col('w'), col('name')) \
        .select(col('w').start, col('w').end, weighted_avg(col('value'), col('count'))) \
        .execute()
result.print()

WeightedAvgクラスのaccumulate(...)メソッドは3つの入力引数を受け付けます。1つ目はaccumulatorで、ほかの2つはユーザ定義の入力です。加重平均値を計算するには、accumulatorは既に蓄積されている全てのデータの加重合計とカウントを保持する必要があります。例では、accumulatorとしてRowオブジェクトを使います。accumulatorはFlinkのチェックポイントの仕組みで管理され、フェイルオーバー時には復元されて、確実に1回のセマンティクスを保証します。

必須のメソッドとオプションのメソッド #

以下のメソッドは各AggregateFunctionで必須です:

  • create_accumulator()
  • accumulate(...)
  • get_value(...)

以下のAggregateFunctionメソッドはユースケースに応じて必要になります:

  • retract(...)は、グループ集計、外部結合などの現在の集計オペレーションの前に、撤回メッセージを生成する可能性のあるオペレーションがある場合に必要です。 このメソッドはオプションですが、UDAFがどのようなユースケースでも使うことができるようにするために実装することを強くお勧めします。
  • merge(...)は、セッションウィンドウとホップウィンドウの集計に必要です。
  • 結果の型とaccumulatorの型がudafデコレータ内で指定されない場合、get_result_type()get_accumulator_type()が必要です。

ListViewとMapView #

accumulatorが大量のデータを保存する必要がある場合は、リストと辞書の代わりにpyflink.table.ListViewpyflink.table.MapViewを使えます。これら2つのデータ構造は、リストと辞書と同様の機能を提供しますが、通常はFlinkの状態バックエンドを利用して不必要な状態アクセスを排除することでパフォーマンスが向上します。 acumulatorの型でDataTypes.LIST_VIEW(...)DataTypes.MAP_VIEW(...)を宣言することで、それらを使うことができます:

from pyflink.table import ListView

class ListViewConcatAggregateFunction(AggregateFunction):

    def get_value(self, accumulator):
        # the ListView is iterable
        return accumulator[1].join(accumulator[0])

    def create_accumulator(self):
        return Row(ListView(), '')

    def accumulate(self, accumulator, *args):
        accumulator[1] = args[1]
        # the ListView support add, clear and iterate operations.
        accumulator[0].add(args[0])

    def get_accumulator_type(self):
        return DataTypes.ROW([
            # declare the first column of the accumulator as a string ListView.
            DataTypes.FIELD("f0", DataTypes.LIST_VIEW(DataTypes.STRING())),
            DataTypes.FIELD("f1", DataTypes.BIGINT())])

    def get_result_type(self):
        return DataTypes.STRING()

現在、ListViewとMapViewを使うには以下の2つの制限があります:

  1. accumulatorはRowである必要があります。
  2. ListViewMapViewRow accumulatorの最初のレベルの子である必要があります。

この高度な機能の詳細については、 documentation of the corresponding classes を参照してください。

注意: Flink状態(accumulatorやデータビューなど)のデータへのアクセスによって発生するPython UDFワーカーとJavaプロセス間のデータ通信コストを削減するために、raw状態ハンドラとPython状態バックエンドの間にはキャッシュされたレイヤがあります。これらの設定オプションを調整して、最高のパフォーマンスが得られるようにキャッシュレイヤーの挙動を変更できます: python.state.cache-sizepython.map-state.read-cache-sizepython.map-state.write-cache-sizepython.map-state.iterate-response-batch-size。 詳細については、Python設定ドキュメントを参照してください。

Table集計関数 #

ユーザ定義table集計関数(UDTAGG)は複数の行のスカラー値を0以上の行(あるいは構造型)にマップします。 返り値のレコードは1つ以上のフィールドで構成される場合があります。出力レコードが1つのフィールドで構成される場合、構造化レコードを省略でき、ランタイムによって暗黙的に行にラップされるスカラー値を発行できます。

注意E: 現在、一般的なユーザ定義table集計関数は、ストリーミングモードのGroupBy集計でのみサポートされています。

集計関数と同様に、table集計の挙動はaccumulatorの概念を中心にしています。 accumulatorは、最終的な集計結果が計算されるまで集計された値を保存する中間データ構造です。

集計する必要がある行のセットごとに、ランタイムはcreate_accumulator()を呼び出して空のaccumulatorを作成します。その後、関数のaccumulate(...)メソッドが入力行ごとに呼び出され、accumulatorが更新されます。全ての行が処理されると、関数のemit_value(...)メソッドが呼び出され、最終的な結果が計算されて返されます。

以下の例は、集計処理を示しています:

UDTAGGの仕組み

例では、飲料に関するデータが含まれるtableを想定しています。tableは3つの列(idnameprice)と5行から構成されています。table内の全ての飲料の最も高い2つの価格を見つけたいと考えています。つまりTOP2()table集計を実行します。5つの行をそれぞれ考慮する必要があります。結果は上位2つの値のtableです。

table集計関数を定義するには、pyflink.tableの基本クラスTableAggregateFunctionを拡張し、accumulate(...)という名前の1つ以上の評価関数を実装する必要があります。

集計関数の結果の型とaccumulatorの型は、次の2つの方法のいずれかで指定できます:

  • get_result_type()get_accumulator_type()という名前のメソッドを実装します。
  • 関数のインスタンスを pyflink.table.udf内のデコレータudtafでラップし、パラメータresult_typeaccumulator_typeを指定します。

以下の例は、独自の集計関数を定義し、クエリで呼び出す方法を示しています。

from pyflink.common import Row
from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
from pyflink.table.expressions import col
from pyflink.table.udf import udtaf, TableAggregateFunction

class Top2(TableAggregateFunction):

    def emit_value(self, accumulator):
        yield Row(accumulator[0])
        yield Row(accumulator[1])

    def create_accumulator(self):
        return [None, None]

    def accumulate(self, accumulator, row):
        if row[0] is not None:
            if accumulator[0] is None or row[0] > accumulator[0]:
                accumulator[1] = accumulator[0]
                accumulator[0] = row[0]
            elif accumulator[1] is None or row[0] > accumulator[1]:
                accumulator[1] = row[0]

    def get_accumulator_type(self):
        return 'ARRAY<BIGINT>'

    def get_result_type(self):
        return 'ROW<a BIGINT>'


env_settings = EnvironmentSettings.in_streaming_mode()
table_env = TableEnvironment.create(env_settings)
# the result type and accumulator type can also be specified in the udtaf decorator:
# top2 = udtaf(Top2(), result_type=DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT())]), accumulator_type=DataTypes.ARRAY(DataTypes.BIGINT()))
top2 = udtaf(Top2())
t = table_env.from_elements([(1, 'Hi', 'Hello'),
                             (3, 'Hi', 'hi'),
                             (5, 'Hi2', 'hi'),
                             (7, 'Hi', 'Hello'),
                             (2, 'Hi', 'Hello')],
                            ['a', 'b', 'c'])

# call function "inline" without registration in Table API
t.group_by(col('b')).flat_aggregate(top2).select(col('*')).execute().print()

# the result is:
#      b    a
# 0  Hi2  5.0
# 1  Hi2  NaN
# 2   Hi  7.0
# 3   Hi  3.0

Top2クラスのaccumulate(...)メソッドは2つの入力を受け取ります:1つ目はaccumulatorで、2つ目はユーザ定義の入力です。結果を計算するには、accumulatorは蓄積されたすべたのデータの上位2つの値を保持する必要があります。accumulatorはFlinkのチェックポイントの仕組みで自動的に管理され、フェイルオーバー時には復元されて、確実に1回のセマンティクスを保証します。 結果の値はランキングのインデックスとともに発行されます。

必須のメソッドとオプションのメソッド #

以下のメソッドは各TableAggregateFunctionで必須です:

  • create_accumulator()
  • accumulate(...)
  • emit_value(...)

以下のTableAggregateFunctionのメソッドはユースケースに応じて必要になります:

  • retract(...)は、グループ集計、外部結合などの現在の集計オペレーションの前に、撤回メッセージを生成する可能性のあるオペレーションがある場合に必要です。 このメソッドはオプションですが、UDTAFがどのようなユースケースでも使うことができるようにするために実装することを強くお勧めします。
  • 結果の型とaccumulatorの型がudtafデコレータ内で指定されない場合、get_result_type()get_accumulator_type()が必要です。

ListViewとMapView #

Aggregation functionと同様に、table集計関数でListViewとMapViewを使うこともできます。

from pyflink.common import Row
from pyflink.table import ListView
from pyflink.table.types import DataTypes
from pyflink.table.udf import TableAggregateFunction

class ListViewConcatTableAggregateFunction(TableAggregateFunction):

    def emit_value(self, accumulator):
        result = accumulator[1].join(accumulator[0])
        yield Row(result)
        yield Row(result)

    def create_accumulator(self):
        return Row(ListView(), '')

    def accumulate(self, accumulator, *args):
        accumulator[1] = args[1]
        accumulator[0].add(args[0])

    def get_accumulator_type(self):
        return DataTypes.ROW([
            DataTypes.FIELD("f0", DataTypes.LIST_VIEW(DataTypes.STRING())),
            DataTypes.FIELD("f1", DataTypes.BIGINT())])

    def get_result_type(self):
        return DataTypes.ROW([DataTypes.FIELD("a", DataTypes.STRING())])
inserted by FC2 system