【大数据】Presto开发自定义聚合函数
Presto 在交互式查詢(xún)?nèi)蝿?wù)中擔(dān)當(dāng)著重要的職責(zé)。隨著越來(lái)越多的人開(kāi)始使用 SQL 在 Presto 上分析數(shù)據(jù),我們發(fā)現(xiàn)需要將一些業(yè)務(wù)邏輯開(kāi)發(fā)成類(lèi)似 Hive 中的 UDF,提高 SQL 使用人員的效率,同時(shí)也保證 Hive 和 Presto 環(huán)境中的 UDF 統(tǒng)一。
1、Presto函數(shù)介紹
在此之前先簡(jiǎn)單介紹下UDF和UDAF,UDF叫做用戶(hù)自定義函數(shù),而UDAF叫做用戶(hù)自定義聚合函數(shù),區(qū)別就在于UDF不會(huì)保存狀態(tài),一行輸入一行輸出,而UDAF會(huì)涉及到狀態(tài)的保存,通過(guò)聚合多個(gè)節(jié)點(diǎn)的數(shù)據(jù)來(lái)轉(zhuǎn)換為最終的輸出結(jié)果。
在 Presto 中,函數(shù)大體分為三種:scalar,aggregation 和 window 類(lèi)型。分別如下:
1)scalar標(biāo)量函數(shù),簡(jiǎn)單來(lái)說(shuō)就是 Java 中的一個(gè)靜態(tài)方法,本身沒(méi)有任何狀態(tài)(不保存數(shù)據(jù),一行輸入一行輸出)。
2)aggregation累積狀態(tài)的函數(shù),或聚集函數(shù),如count,avg。如果只是單節(jié)點(diǎn),單機(jī)狀態(tài)可以直接用一個(gè)變量存儲(chǔ)即可,但是presto是分布式計(jì)算引擎,狀態(tài)數(shù)據(jù)會(huì)在多個(gè)節(jié)點(diǎn)之間傳輸,因此狀態(tài)數(shù)據(jù)需要被序列化成 Presto 的內(nèi)部格式才可以被傳輸。簡(jiǎn)單來(lái)說(shuō)Aggregation對(duì)應(yīng)于多行輸入一行輸出。
3)window 窗口函數(shù),窗口函數(shù)在查詢(xún)結(jié)果的行上進(jìn)行計(jì)算,執(zhí)行順序在HAVING子句之后,ORDER BY子句之前。在 Presto SQL 中,窗口函數(shù)的語(yǔ)法形式如下:
windowFunction(arg1,....argn) OVER([PARTITION BY<...>] [ORDER BY<...>] [RANGE|ROWS BETWEEN AND])窗口函數(shù)語(yǔ)法由關(guān)鍵字OVER觸發(fā),且包含三個(gè)子句:
PARTITION BY: 指定輸入行分區(qū)的規(guī)則,類(lèi)似于聚合函數(shù)的GROUP BY子句,不同分區(qū)里的計(jì)算互不干擾(窗口函數(shù)的計(jì)算是并發(fā)進(jìn)行的,并發(fā)數(shù)和partition數(shù)量一致),缺省時(shí)將所有數(shù)據(jù)行視為一個(gè)分區(qū)
ORDER BY: 決定了窗口函數(shù)處理輸入行的順序
RANGE|ROWS BETWEEN AND: 指定窗口邊界,不常用,缺省時(shí)的窗口為當(dāng)前行所在的分區(qū)第一行到當(dāng)前行。
2、自定義函數(shù)
官方文檔地址:https://prestodb.io/docs/current/develop/functions.html
2.1自定義Scalar函數(shù)的實(shí)現(xiàn)
2.1.1定義一個(gè)java類(lèi)
1)用 @ScalarFunction 的 Annotation 標(biāo)記實(shí)現(xiàn)業(yè)務(wù)邏輯的靜態(tài)方法。
2)用 @Description 描述函數(shù)的作用,這里的內(nèi)容會(huì)在 SHOW FUNCTIONS 中顯示。
3)用@SqlType 標(biāo)記函數(shù)的返回值類(lèi)型,如返回字符串,因此是 StandardTypes.VARCHAR。
4)Java 方法的返回值必須使用 Presto 內(nèi)部的序列化方式,因此字符串類(lèi)型必須返回 Slice, 使用 Slices.utf8Slice 方法可以方便的將 String 類(lèi)型轉(zhuǎn)換成 Slice 類(lèi)型
public class ExampleStringFunction {@ScalarFunction("lowercaser")@Description("converts the string to alternating case")@SqlType(StandardTypes.VARCHAR)public static Slice lowercaser(@SqlType(StandardTypes.VARCHAR) Slice slice){String argument = slice.toStringUtf8();return Slices.utf8Slice(argument.toLowerCase());} }2.2 自定義Aggregation函數(shù)
2.2.1實(shí)現(xiàn)原理步驟
Presto 把 Aggregation 函數(shù)分解成三個(gè)步驟執(zhí)行:
1、input(state, data): 針對(duì)每條數(shù)據(jù),執(zhí)行 input 函數(shù)。這個(gè)過(guò)程是并行執(zhí)行的,因此在每個(gè)有數(shù)據(jù)的節(jié)點(diǎn)都會(huì)執(zhí)行,最終得到多個(gè)累積的狀態(tài)數(shù)據(jù)。
2、combine(state1, state2):將所有節(jié)點(diǎn)的狀態(tài)數(shù)據(jù)聚合起來(lái),多次執(zhí)行,直至所有狀態(tài)數(shù)據(jù)被聚合成一個(gè)最終狀態(tài),也就是 Aggregation 函數(shù)的輸出結(jié)果。
3、output(final_state, out):最終輸出結(jié)果到一個(gè) BlockBuilder。
2.2.2 具體代碼實(shí)現(xiàn)過(guò)程
1、定義一個(gè) Java 類(lèi),使用 @AggregationFunction 標(biāo)記為 Aggregation 函數(shù)
2、使用 @InputFunction、 @CombineFunction、@OutputFunction 分別標(biāo)記計(jì)算函數(shù)、合并結(jié)果函數(shù)和最終輸出函數(shù)在 Plugin 處注冊(cè) Aggregation 函數(shù)
3、一個(gè)繼承AccumulatorState的State接口,get和set方法
4、并使用 @AccumulatorStateMetadata 提供序列化(stateSerializerClass指定)和 Factory 類(lèi)信息(stateFactoryClass指定)。自己寫(xiě)一個(gè)序列化類(lèi)和一個(gè)工廠(chǎng)類(lèi)。(復(fù)雜數(shù)據(jù)類(lèi)型需要:自定義類(lèi)保存狀態(tài)、Map、List等)
簡(jiǎn)單類(lèi)型Aggregation
對(duì)于簡(jiǎn)單數(shù)據(jù)類(lèi)型的聚合函數(shù)編寫(xiě)比較簡(jiǎn)單,實(shí)現(xiàn)一個(gè)包含input、combine、output的aggregation和一個(gè)狀態(tài)設(shè)定接口State提供get、set方法即可,不用去關(guān)心序列化和狀態(tài)保存問(wèn)題。
Aggregation:
LongAndDoubleState :寫(xiě)一個(gè)接口實(shí)現(xiàn)繼承自AccumulatorState類(lèi),提供get、set方法即可。
public interface LongAndDoubleStateextends AccumulatorState {long getLong();void setLong(long value);double getDouble();void setDouble(double value); }復(fù)雜類(lèi)型Aggregation
對(duì)于復(fù)雜數(shù)據(jù)類(lèi)型則需要提供序列化機(jī)制,你要序列化那些東西都是由你來(lái)制指定的。在AccumulatorState的接口上用注解指定@AccumulatorStateMetadata 提供序列化(stateSerializerClass指定)和 Factory 類(lèi)信息(stateFactoryClass指定),自定義一個(gè)序列化器和序列化工廠(chǎng)類(lèi),實(shí)現(xiàn)類(lèi)的序列化和反序列化。
Aggregation類(lèi): 這個(gè)類(lèi)實(shí)現(xiàn)比較簡(jiǎn)單,和簡(jiǎn)單數(shù)據(jù)類(lèi)型的實(shí)現(xiàn)一樣,input、combine、output。
@AggregationFunction("presto_collect") public class CollectListAggregation {@InputFunctionpublic static void input(@AggregationState CollectState state, @SqlType(StandardTypes.VARCHAR) Slice id,@SqlType(StandardTypes.VARCHAR) Slice key) {try {CollectListStats stats = state.get();if (stats == null) {stats = new CollectListStats();state.set(stats);}int inputId = Integer.parseInt(id.toStringUtf8());String inputKey = key.toStringUtf8();stats.addCollectList(inputId,inputKey, 1);} catch (Exception e) {throw new RuntimeException(e+" --------- input err");}}@CombineFunctionpublic static void combine(@AggregationState CollectState state, CollectState otherState) {try {CollectListStats collectListStats = state.get();CollectListStats oCollectListStats = otherState.get();if(collectListStats == null) {state.set(oCollectListStats);} else {collectListStats.mergeWith(oCollectListStats);}}catch (Exception e) {throw new RuntimeException(e+" --------- combine err");}}@OutputFunction(StandardTypes.*VARCHAR*)public static void output(@AggregationState CollectState state, BlockBuilder out) {try {CollectListStats stats = state.get();if (stats == null) {out.appendNull();return;}// 統(tǒng)計(jì)Slice result = stats.getCollectResult();if (result == null) {out.appendNull();} else {VarcharType.VARCHAR.writeSlice(out, result);}} catch (Exception e) {throw new RuntimeException(e+" -------- output err");}} }狀態(tài)保存接口:
@AccumulatorStateMetadata(stateSerializerClass = CollectListStatsSerializer.class, stateFactoryClass = CollectListStatsFactory.class) public interface CollectState extends AccumulatorState {CollectListStats get();void set(CollectListStats value); }存放數(shù)據(jù)的類(lèi):此類(lèi)需要實(shí)現(xiàn)數(shù)據(jù)的序列化和反序列化,這是最關(guān)鍵和比較麻煩的地方,貼一個(gè)例子,關(guān)鍵在于需要自己控制存儲(chǔ)空間以及數(shù)據(jù)的順序,和讀取的時(shí)候按照一定順序讀取。對(duì)于字符要先存儲(chǔ)長(zhǎng)度,然后是字節(jié),讀取則先讀取字符長(zhǎng)度,然后讀取這么長(zhǎng)的數(shù)據(jù),最后轉(zhuǎn)化為字符。
public class CollectListStats {private static final int INSTANCE_SIZE = ClassLayout.parseClass(CollectListStats.class).instanceSize();//<id,<key,value>>private Map<Integer,Map<String,Integer>> collectContainer = new HashMap<>();private long contentEstimatedSize = 0;private int keyByteLen = 0;private int keyListLen = 0;CollectListStats() {}CollectListStats(Slice serialized) {deserialize(serialized);}void addCollectList(Integer id, String key, int value) {if (collectContainer.containsKey(id)) {Map<String, Integer> tmpMap = collectContainer.get(id);if (tmpMap.containsKey(key)) {tmpMap.put(key, tmpMap.get(key)+value);}else{tmpMap.put(key,value);contentEstimatedSize += ( key.getBytes().length + SizeOf.SIZE_OF_INT*);keyByteLen += key.getBytes().length;keyListLen++;}} else {Map<String,Integer> tmpMap = new HashMap<String,Integer>();tmpMap.put(key, value);keyByteLen += key.getBytes().length;keyListLen++;collectContainer.put(id, tmpMap);contentEstimatedSize += SizeOf.SIZE_OF_INT;}}//[{id:1,{"aaa":3,"fadf":6},{}]Slice getCollectResult() {Slice jsonSlice = null;try {StringBuilder jsonStr = new StringBuilder();jsonStr.append("[");int collectLength = collectContainer.entrySet().size();for (Map.Entry<Integer, Map<String, Integer>> mapEntry : collectContainer.entrySet()) {Integer id = mapEntry.getKey();Map<String, Integer> vMap = mapEntry.getValue();jsonStr.append("{id:").append(id).append(",{");int vLength = vMap.entrySet().size();for (Map.Entry<String, Integer> vEntry : vMap.entrySet()) {String key = vEntry.getKey();Integer value = vEntry.getValue();jsonStr.append(key).append(":").append(value);vLength--;if (vLength != 0) {jsonStr.append(",");}}jsonStr.append("}");collectLength--;if (collectLength != 0) {jsonStr.append(",");}}jsonStr.append("]");jsonSlice = Slices.utf8Slice*(jsonStr.toString());} catch (Exception e) {throw new RuntimeException(e+" ---------- get CollectResult err");}return jsonSlice;}public void deserialize(Slice serialized) {try {SliceInput input = serialized.getInput();//外層map的長(zhǎng)度int collectStatsEntrySize = input.readInt();for (int collectCnt = 0; collectCnt < collectStatsEntrySize; collectCnt++) {int id = input.readInt();int keyEntrySize = input.readInt();for (int idCnt = 0; idCnt < keyEntrySize; idCnt++) {int keyBytesLen = input.readInt();byte[] keyBytes = new byte[keyBytesLen];for (int byteIdx = 0; byteIdx < keyBytesLen; byteIdx++) {keyBytes[byteIdx] = input.readByte();}String key = new String(keyBytes);int value = input.readInt();addCollectList(id, key, value);}}} catch (Exception e) {throw new RuntimeException(e+" ----- deserialize err");}}public Slice serialize() {SliceOutput builder = null;int requiredBytes = //對(duì)應(yīng) SliceOutput builder append的內(nèi)容所占用的空間SizeOf.SIZE_OF_INT*3 //id entry數(shù)目,id數(shù)值,key Entry數(shù)目\+ keyListLen * SizeOf.SIZE_OF_INT* //key bytes長(zhǎng)度\+ keyByteLen* //key byte總長(zhǎng)度\+ keyListLen * SizeOf.SIZE_OF_INT; //valuetry {// 序列化builder = Slices.*allocate*(requiredBytes).getOutput();for (Map.Entry<Integer,Map<String, Integer>> entry : collectContainer.entrySet()) {//id個(gè)數(shù)builder.appendInt(collectContainer.entrySet().size());//id 數(shù)值builder.appendInt(entry.getKey());Map<String, Integer> kMap = entry.getValue();builder.appendInt(kMap.entrySet().size());for (Map.Entry<String, Integer> vEntry : kMap.entrySet()) {byte[] keyBytes = vEntry.getKey().getBytes();builder.appendInt(keyBytes.length);builder.appendBytes(keyBytes);builder.appendInt(vEntry.getValue());}}return builder.getUnderlyingSlice();} catch (Exception e) {throw new RuntimeException(e+" ---- serialize err requiredBytes = " + requiredBytes + " keyByteLen= " + keyByteLen + " keyListLen = " + keyListLen);}}long estimatedInMemorySize() {return INSTANCE_SIZE + contentEstimatedSize;}void mergeWith(CollectListStats other) {if (other == null) {return;}for (Map.Entry<Integer,Map<String, Integer>> cEntry : other.collectContainer.entrySet()) {Integer id = cEntry.getKey();Map<String, Integer> kMap = cEntry.getValue();for (Map.Entry<String, Integer> kEntry : kMap.entrySet()) {addCollectList(id, kEntry.getKey(), kEntry.getValue());}}} }上面的例子我是直接從別人那兒拿過(guò)來(lái)的(個(gè)人比較懶:https://www.cnblogs.com/lrxvx/p/12558902.html),實(shí)際方式也很簡(jiǎn)單,就是實(shí)現(xiàn)序列化和反序列化方法以及一個(gè)管理存儲(chǔ)空間的方法。需要注意的是序列化和反序列化時(shí)候的順序一定要保證,Presto提供了許多屬性方式的選項(xiàng)如int、long、byte,對(duì)于String方式序列化,可以將String轉(zhuǎn)為byte再進(jìn)行序列化,思路就是先序列化一個(gè)長(zhǎng)度進(jìn)去,再將字節(jié)內(nèi)容序列化,反序列化的時(shí)候先讀length,再讀相應(yīng)的字節(jié)內(nèi)容轉(zhuǎn)為String就好了,而對(duì)象類(lèi)型的屬性,本質(zhì)上還是可以直接序列化屬性,反序列化時(shí)候重新創(chuàng)建對(duì)象,內(nèi)容沒(méi)變。Presto的序列化方式比較高效,原因是因?yàn)槲铱梢灾恍蛄谢蚁胍膶傩跃秃昧?#xff0c;缺點(diǎn)是擴(kuò)展性不足。
序列化類(lèi):
public class CollectListStatsSerializer implements AccumulatorStateSerializer<CollectState> {@Overridepublic Type getSerializedType() {return VARBINARY;}@Overridepublic void serialize(CollectState state, BlockBuilder out) {if (state.get() == null) {out.appendNull();} else {VARBINARY.writeSlice(out, state.get().serialize());}}@Overridepublic void deserialize(Block block, int index, CollectState state) {state.set(new CollectListStats(VARBINARY.getSlice(block, index)));} }序列化工廠(chǎng)類(lèi):
public class CollectListStatsFactory implements AccumulatorStateFactory<CollectState> {@Overridepublic CollectState createSingleState() {return new SingleState();}@Overridepublic Class<? extends CollectState> getSingleStateClass() {return SingleState.class;}@Overridepublic CollectState createGroupedState() {return new GroupState();}@Overridepublic Class<? extends CollectState> getGroupedStateClass() {return GroupState.class;}public static class GroupState implements GroupedAccumulatorState, CollectState {private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedDigestAndPercentileState.class).instanceSize();private final ObjectBigArray<CollectListStats> collectStatsList = new ObjectBigArray<>();private long size;private long groupId;@Overridepublic void setGroupId(long groupId) {this.groupId = groupId;}@Overridepublic void ensureCapacity(long size) {collectStatsList.ensureCapacity(size);}@Overridepublic CollectListStats get() {return collectStatsList.get(groupId);}@Overridepublic void set(CollectListStats value) {CollectListStats previous = get();if (previous != null) {size -= previous.estimatedInMemorySize();}collectStatsList.set(groupId, value);size += value.estimatedInMemorySize();}@Overridepublic long getEstimatedSize() {return INSTANCE_SIZE +size + collectStatsList.sizeOf();}}public static class SingleState implements CollectState{private CollectListStats stats;@Overridepublic CollectListStats get() {return stats;}@Overridepublic void set(CollectListStats value) {stats = value;}@Overridepublic long getEstimatedSize() {if (stats == null) {return 0;}return stats.estimatedInMemorySize();}} }驗(yàn)證自定義函數(shù)
當(dāng)我們開(kāi)發(fā)好自定義函數(shù)后如何驗(yàn)證呢,一種方式是使用Presto內(nèi)置函數(shù)注冊(cè)機(jī)制進(jìn)行單元測(cè)試,Presto 函數(shù)由MetadataManager中的FunctionRegistry進(jìn)行管理,開(kāi)發(fā)的函數(shù)要生效必須要先注冊(cè)到FunctionRegistry中。函數(shù)注冊(cè)是在 Presto 服務(wù)啟動(dòng)過(guò)程中進(jìn)行的,有以下兩種方式進(jìn)行函數(shù)注冊(cè)。
FunctionListBuilder builder = new FunctionListBuilder().window(RowNumberFunction.class).aggregate(ApproximateCountDistinctAggregation.class).scalar(RepeatFunction.class).function(MAP_HASH_CODE)......注冊(cè)好之后就可以編寫(xiě)相應(yīng)的單元測(cè)試代碼了。完整的Aggregation測(cè)試代碼如下:
import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.FunctionListBuilder; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.operator.aggregation.InternalAggregationFunction; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test;import static com.facebook.presto.block.BlockAssertions.createSlicesBlock; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;public class TestAggregation{private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();private static InternalAggregationFunction getAggregation(Type... arguments){return FUNCTION_AND_TYPE_MANAGER.getAggregateFunctionImplementation(FUNCTION_AND_TYPE_MANAGER.lookupFunction("presto_collect", fromTypes(arguments)));}private static final InternalAggregationFunction COLLECTION_AGGREGATION = getAggregation(VARCHAR, VARCHAR); //和Aggregation中的類(lèi)型對(duì)應(yīng),java類(lèi)型的Slice對(duì)應(yīng)Varchar@BeforeClasspublic void init(){FunctionListBuilder builder = new FunctionListBuilder().aggregate(CollectListAggregation.class);FUNCTION_AND_TYPE_MANAGER.registerBuiltInFunctions(builder.getFunctions());}@Testpublic void collectionAggregationTest(){String result="xxx"; //你期望的aggregation結(jié)果Slice str1= Slices.utf8Slice("x");Slice str2= Slices.utf8Slice("y");assertAggregation(COLLECTION_AGGREGATION,result,createSlicesBlock(str1, str2),createSlicesBlock(str1, str2));} }標(biāo)量函數(shù)單元測(cè)試
而對(duì)于標(biāo)量函數(shù)scalar的測(cè)試略有不同,示例如下:
public class TestBitwiseFunctionsextends AbstractTestFunctions {@Testpublic void testBitCount(){assertFunction("bit_count(0, 64)", BIGINT, 0L); //bit_count為標(biāo)量函數(shù)名,傳參,參數(shù)如果為String則用單引號(hào),參數(shù)類(lèi)型,期望結(jié)果} }當(dāng)然進(jìn)行單元測(cè)試后,我們期望到真實(shí)的庫(kù)中去驗(yàn)證,內(nèi)置函數(shù)滿(mǎn)足不了使用需求時(shí),就需要自行開(kāi)發(fā)函數(shù)來(lái)拓展函數(shù)庫(kù)。開(kāi)發(fā)者自行編寫(xiě)的拓展函數(shù)一般通過(guò)插件的方式進(jìn)行注冊(cè)。PluginManager在安裝插件時(shí)會(huì)調(diào)用插件的getFunctions()方法,將獲取到的函數(shù)集合通過(guò)MetadataManager的addFunctions方法進(jìn)行注冊(cè):
public class ExampleFunctionsPluginimplements Plugin {@Overridepublic Set<Class<?>> getFunctions(){return ImmutableSet.<Class<?>>builder().add(ExampleNullFunction.class).add(IsNullFunction.class).add(IsEqualOrNullFunction.class).add(ExampleStringFunction.class).add(ExampleAverageFunction.class).build();} }Presto 函數(shù)的注冊(cè)機(jī)制,新增和修改函數(shù)后,必須要重啟服務(wù)才能生效,所以目前還不支持真正的用戶(hù)自定義函數(shù)。插件函數(shù)進(jìn)行注冊(cè)之后,在resource下創(chuàng)建META-INF/services目錄,并創(chuàng)建文件名為com.facebook.presto.spi.Plugin的文件,并添加內(nèi)容:
xxx.xxx.xxx.ExampleFunctionsPlugin然后利用presto的插件打包,此時(shí)會(huì)在target目錄下生成zip文件,把xxx.zip解壓到${PRESTOHOME}/plugin,重啟presto服務(wù)即可進(jìn)行驗(yàn)證。
總的來(lái)說(shuō),Presto的UDF和UDAF開(kāi)發(fā)總結(jié)為一張圖:
注意:各個(gè)版本的Presto源碼有所不同,遇到類(lèi)不正確的對(duì)版本進(jìn)行調(diào)整,上面是用的Presto版本為0.264,更多的參考Presto的官方源碼https://github.com/prestodb/presto,而對(duì)于Persto的分組聚合查詢(xún)流程可以參見(jiàn):Presto中的分組聚合查詢(xún)流程
總結(jié)
以上是生活随笔為你收集整理的【大数据】Presto开发自定义聚合函数的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 抖音四面被拒,再战头条终获offer,面
- 下一篇: 经济与社会发展研究杂志社经济与社会发展研