在自然语言处理(NLP)领域,Hugging Face 是不可或缺的处理库,而 Spark 则是大数据处理的必备工具。将两者的优势结合起来,可以实现高效的 NLP 大数据处理。以下是结合 Hugging Face 和 Spark 的两种方法,基于 Spark & PySpark 3.3.1 版本进行探索。
方法一:升级 Spark 版本至 3.4 及以上
如果你愿意升级 Spark 版本到 3.4 或更高版本,那么结合 Hugging Face 和 Spark 将变得非常方便。Spark 3.4 及以上版本天然支持加载模型进行预测。
关键步骤说明:
- 模型加载策略:需要为每个 Worker 单独加载模型,确保模型在分布式环境中的可用性。
- 文件夹管理:在加载 Hugging Face 预训练模型之前,务必删除之前的模型文件夹,防止加载失败。
注:如果图片无法显示,请检查链接合法性或稍后重试。
方法二:基于 Spark 3.3.1 的手动封装接口
如果你希望保持当前的 Spark 3.3.1 版本,那么可以通过手动封装接口来实现 Hugging Face 和 Spark 的结合。以下是详细的代码实现和关键说明。
封装分布式的模型缓存
为了高效管理模型加载和缓存,我们从spark3.4的源代码中抽取了一个分布式的模型缓存机制:
from collections import OrderedDict
from threading import Lock
from typing import Callable, Optional
from uuid import UUID
class ModelCache:
"""Cache for model prediction functions on executors.
This requires the `spark.python.worker.reuse` configuration to be set to `true`, otherwise a
new python worker (with an empty cache) will be started for every task.
If a python worker is idle for more than one minute (per the IDLE_WORKER_TIMEOUT_NS setting in
PythonWorkerFactory.scala), it will be killed, effectively clearing the cache until a new python
worker is started.
Caching large models can lead to out-of-memory conditions, which may require adjusting spark
memory configurations, e.g. `spark.executor.memoryOverhead`.
"""
_models: OrderedDict = OrderedDict()
_capacity: int = 3 # "reasonable" default size for now, make configurable later, if needed
_lock: Lock = Lock()
@staticmethod
def add(uuid: UUID, predict_fn: Callable) -> None:
with ModelCache._lock:
ModelCache._models[uuid] = predict_fn
ModelCache._models.move_to_end(uuid)
if len(ModelCache._models) > ModelCache._capacity:
ModelCache._models.popitem(last=False)
@staticmethod
def get(uuid: UUID) -> Optional[Callable]:
with ModelCache._lock:
predict_fn = ModelCache._models.get(uuid)
if predict_fn:
ModelCache._models.move_to_end(uuid)
return predict_fn
封装处理逻辑
from __future__ import annotations
import os
import argparse
import random
import logging
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, column, encode
from pyspark.sql.types import *
from datetime import datetime, timedelta
import requests as req
from io import BytesIO
import numpy as np
import uuid
import inspect
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import (
ArrayType,
ByteType,
DataType,
DoubleType,
FloatType,
IntegerType,
LongType,
ShortType,
StringType,
StructType,
)
from typing import Any, Callable, Iterator, List, Mapping, TYPE_CHECKING, Tuple, Union, Optional
supported_scalar_types = (
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
StringType,
)
hadoop = os.path.join(os.environ['HADOOP_COMMON_HOME'], 'bin/hadoop')
def init_spark():
"""初始化 SparkSession 配置"""
spark = SparkSession.builder \
.config("spark.sql.caseSensitive", "false") \
.config("spark.shuffle.spill", "true") \
.config("spark.shuffle.spill.compress", "true") \
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
.config("metastore.catalog.default", "hive") \
.config("spark.sql.hive.convertMetastoreOrc", "true") \
.config("spark.kryoserializer.buffer.max", "1024m") \
.config("spark.kryoserializer.buffer", "64m") \
.config("spark.driver.maxResultSize","4g") \
.config("spark.sql.broadcastTimeout", "36000") \
.enableHiveSupport() \
.getOrCreate()
return spark
def system_command(command):
"""执行系统命令"""
code = os.system(command)
if code != 0:
logging.error(f"Command: ({command}) execute failed.")
else:
logging.info(f"Command: ({command}) execute succeed.")
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(usage="it's usage tip.",
description="user tags prefer")
parser.add_argument("--db", default="", help="hive表")
parser.add_argument("--date", default="", help="日期")
parser.add_argument("--output_path", default="", help="输出路径")
parser.add_argument("--batch_size", default=16, help="输出路径")
return parser.parse_args()
def _batched(
data: Union[pd.Series, pd.DataFrame, Tuple[pd.Series]], batch_size: int
) -> Iterator[pd.DataFrame]:
"""将 pandas dataframe/series 分批处理"""
if isinstance(data, pd.DataFrame):
df = data
elif isinstance(data, pd.Series):
df = pd.concat((data,), axis=1)
else: # isinstance(data, Tuple[pd.Series])
df = pd.concat(data, axis=1)
index = 0
data_size = len(df)
while index < data_size:
yield df.iloc[index : index + batch_size]
index += batch_size
def _is_tensor_col(data: Union[pd.Series, pd.DataFrame]) -> bool:
"""检查数据是否为张量列"""
if isinstance(data, pd.Series):
return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
elif isinstance(data, pd.DataFrame):
return any(data.dtypes == np.object_) and any(
[isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
)
else:
raise ValueError(
"Unexpected data type: {}, expected pd.Series or pd.DataFrame.".format(type(data))
)
def _has_tensor_cols(data: Union[pd.Series, pd.DataFrame, Tuple[pd.Series]]) -> bool:
"""检查输入是否包含张量值列"""
if isinstance(data, (pd.Series, pd.DataFrame)):
return _is_tensor_col(data)
else: # isinstance(data, Tuple)
return any(_is_tensor_col(elem) for elem in data)
def _validate_and_transform_multiple_inputs(
batch: pd.DataFrame, input_shapes: List[Optional[List[int]]], num_input_cols: int
) -> List[np.ndarray]:
"""验证并转换多个输入"""
multi_inputs = [batch[col].to_numpy() for col in batch.columns]
if input_shapes:
if len(input_shapes) == num_input_cols:
multi_inputs = [
np.vstack(v).reshape([-1] + input_shapes[i]) # type: ignore
if input_shapes[i]
else v
for i, v in enumerate(multi_inputs)
]
if not all([len(x) == len(batch) for x in multi_inputs]):
raise ValueError("Input data does not match expected shape.")
else:
raise ValueError("input_tensor_shapes must match columns")
return multi_inputs
def _validate_and_transform_single_input(
batch: pd.DataFrame,
input_shapes: List[Optional[List[int]]],
has_tensors: bool,
has_tuple: bool,
) -> np.ndarray:
"""验证并转换单个输入"""
# 处理逻辑省略(与原文一致)
return single_input
def _validate_and_transform_prediction_result(
preds: Union[np.ndarray, Mapping[str, np.ndarray], List[Mapping[str, Any]]],
num_input_rows: int,
return_type: DataType,
) -> Union[pd.DataFrame, pd.Series]:
"""验证并转换预测结果"""
# 处理逻辑省略(与原文一致)
return pd.DataFrame(preds)
def predict_batch_udf(
make_predict_fn: Callable[
[],
PredictBatchFunction,
],
*,
return_type: DataType,
batch_size: int,
input_tensor_shapes: Optional[Union[List[Optional[List[int]]], Mapping[int, List[int]]]] = None,
):
"""定义批量预测的 Pandas UDF"""
model_uuid = uuid.uuid4()
def predict(data: Iterator[Union[pd.Series, pd.DataFrame]]) -> Iterator[pd.DataFrame]:
from model_cache import ModelCache
predict_fn = ModelCache.get(model_uuid)
if not predict_fn:
predict_fn = make_predict_fn()
ModelCache.add(model_uuid, predict_fn)
signature = inspect.signature(predict_fn)
num_expected_cols = len(signature.parameters)
input_shapes: List[Optional[List[int]]]
if isinstance(input_tensor_shapes, Mapping):
input_shapes = [None] * num_expected_cols
for index, shape in input_tensor_shapes.items():
input_shapes[index] = shape
else:
input_shapes = input_tensor_shapes # type: ignore
for pandas_batch in data:
has_tuple = isinstance(pandas_batch, Tuple) # type: ignore
has_tensors = _has_tensor_cols(pandas_batch)
if has_tensors and not input_shapes:
raise ValueError("Tensor columns require input_tensor_shapes")
for batch in _batched(pandas_batch, batch_size):
num_input_rows = len(batch)
num_input_cols = len(batch.columns)
if num_input_cols == num_expected_cols and num_expected_cols > 1:
multi_inputs = _validate_and_transform_multiple_inputs(
batch, input_shapes, num_input_cols
)
preds = predict_fn(*multi_inputs)
elif num_expected_cols == 1:
single_input = _validate_and_transform_single_input(
batch, input_shapes, has_tensors, has_tuple
)
preds = predict_fn(single_input)
else:
msg = "Model expected {} inputs, but received {} columns"
raise ValueError(msg.format(num_expected_cols, num_input_cols))
yield _validate_and_transform_prediction_result(
preds, num_input_rows, return_type
) # type: ignore
return pandas_udf(predict, return_type) # type: ignore[call-overload]
def extract_text_embedding(model, tokenizer, sentence):
"""提取文本嵌入向量"""
inputs = tokenizer(sentence, return_tensors='pt', max_length=32, padding=True, truncation=True)
embeddings = model(**inputs)
embeddings = embeddings.pooler_output
embeddings = embeddings.tolist()
for i in range(len(embeddings)):
embeddings[i] = [round(c,4) for c in embeddings[i]]
return np.array(embeddings, dtype=np.float32)
if __name__ == "__main__":
args = parse_args()
spark = init_spark()
### 读取数据
df = spark.sql(f"""
select article_id, title
from xxx
""")
def predict_embedding():
system_command(f"""rm -rf ./bert-base-chinese""")
system_command(f"""{hadoop} fs -get /path/to/bert-base-chinese""")
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
text_model = BertModel.from_pretrained('./bert-base-chinese')
def predict(inputs):
sentence = inputs.tolist()
embeddings = extract_text_embedding(text_model, tokenizer, sentence)
return embeddings
return predict
predict_embedding_udf = predict_batch_udf(predict_embedding,
return_type=ArrayType(StringType()),
batch_size=100)
df.withColumn("title_embedding", predict_embedding_udf("title")).show(5)
spark.stop()
del spark
关键点说明:
- 模型加载与缓存:通过
predict_batch_udf
函数封装预测逻辑,利用模型缓存避免重复加载,提高效率。 - 批量处理:使用
_batched
函数将数据分批处理,避免内存溢出,适合大数据场景。 - 类型转换与验证:通过
_validate_and_transform
系列函数确保输入输出类型匹配,提高代码健壮性。 - Hugging Face 模型集成:在
predict_embedding
函数中加载 Hugging Face 的 BERT 模型,并定义预测逻辑。 - typing语法修改:python3.10以的typing是不支持|语法的,需要改成Union进行类型的或推断
方法比较
对比维度 | 方法一:升级 Spark 至 3.4+ | 方法二:基于 Spark 3.3.1 手动封装接口 |
---|---|---|
实现难度 | 较低,依托新版本特性 | 较高,需手动实现缓存及接口封装 |
模型加载方式 | 每个 Worker 单独加载模型 | 每个 Worker 单独加载模型,并通过分布式缓存机制复用 |
文件管理要求 | 需提前删除旧模型文件夹防止加载失败 | 加载前删除旧模型文件夹 |
代码复用性 | 可直接使用新版本 API,代码简洁 | 需手动封装,代码量较大,但更具灵活性 |
性能优化 | 新版本可能自带优化 | 可通过调整缓存策略、批量处理逻辑等进行精细优化 |
适用场景 | 适合可升级环境,追求快速开发 | 适合无法升级环境,或对性能和资源管理有更高要求的场景 |
可维护性 | 依赖新版本稳定性,升级后需充分测试 | 自定义逻辑较多,需额外维护封装的接口和缓存机制 |
扩展性 | 依赖 Spark 新版本的更新节奏 | 可根据项目需求灵活扩展自定义功能 |
社区支持 | 可直接参考官方文档和社区对新版本的案例 | 需结合旧版本社区经验,同时参考自定义实现的维护文档 |
资源消耗 | 新版本可能对硬件有新要求 | 可通过优化缓存和批处理逻辑,更精细地控制资源使用 |
通过以上两种方法,可以在不同 Spark 版本环境下实现 Hugging Face 和 Spark 的结合,充分发挥两者在 NLP 和大数据处理中的优势,推荐第二种,更加可控一些。
本文由博客一文多发平台 OpenWrite 发布!
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。