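Secure GPT-2 Inference with the Puma Framework#

In this lab, we show how to use Puma to securely generate text with a pretrained GPT-2 model.

First, we show how to generate text with a pretrained GPT-2 model using JAX and the Hugging Face transformers library. Then, we show how to generate text on Puma with only a few code changes.

The code below is for demonstration only. Do not use it directly in production.

This tutorial may require more resources than 16c48g (16 CPU cores, 48 GB of memory).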

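What Is Puma?#

Puma is a fast and accurate end-to-end framework for secure three-party Transformer inference. Puma designs high-quality approximations for expensive nonlinear functions such as \(\mathsf{GeLU}\) and \(\mathsf{Softmax}\), which greatly reduces the cost of secure inference while preserving model performance. In addition, we designed secure implementations of the \(\mathsf{Embedding}\) and \(\mathsf{LayerNorm}\) operators, so that secure inference works without changing the model architecture. Puma is about 2x more efficient than MPCFormer (ICLR 2023), one of the previous state-of-the-art approaches, and it reaches accuracy and other metrics on par with plaintext inference without fine-tuning the provided model (earlier secure Transformer inference frameworks all required further fine-tuning after modifying the model architecture).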

Generating Text with GPT-2 Using JAX/Flax#

Install the transformers Library#

[ ]:
import sys

!{sys.executable} -m pip install transformers[flax]

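The JAX version required by the transformers library differs from the one required by SPU, but this does not affect the examples in this tutorial.

Load the Pretrained GPT-2 Model#

See this documentation for more details on running GPT-2 with Flax.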

[2]:
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config

tokenizer = AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

To hijack the GeLU function in the GPT-2 model, ``self.act`` must be replaced with ``jax.nn.gelu``. For example, in ``transformers/src/transformers/models/gpt2/modeling_flax_gpt2.py``, line 296, change

hidden_states = self.act(hidden_states)

to

hidden_states = jax.nn.gelu(hidden_states)
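
The reason for this change can be illustrated with a small sketch that uses no external libraries. The names ``fake_nn``, ``captured_act``, and ``late_lookup_act`` are hypothetical stand-ins, not part of transformers or JAX. The assumption is that ``self.act`` holds a reference to the activation function that was bound before any patching takes place, so replacing the module attribute later does not affect it, whereas a call written as ``jax.nn.gelu(...)`` looks the attribute up each time it runs.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a module such as jax.nn.
fake_nn = SimpleNamespace(gelu=lambda x: ("original", x))

# Reference bound early, analogous to storing the activation in self.act.
captured_act = fake_nn.gelu


def late_lookup_act(x):
    # The attribute is resolved on every call, analogous to jax.nn.gelu(x).
    return fake_nn.gelu(x)


# Monkeypatch the module attribute, as the hijack context managers do later.
fake_nn.gelu = lambda x: ("patched", x)

print(captured_act(1.0))     # still calls the original function
print(late_lookup_act(1.0))  # picks up the patched function
```

Only the late lookup sees the replacement, which is why the call site has to name ``jax.nn.gelu`` directly.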

Define the Text Generation Function#

We use a greedy search strategy to generate text.

[3]:
import jax.numpy as jnp


def text_generation(input_ids, params):
    config = GPT2Config()
    model = FlaxGPT2LMHeadModel(config=config)

    # Greedy decoding: append the most likely next token 10 times.
    for _ in range(10):
        outputs = model(input_ids=input_ids, params=params)
        # Logits for the last position of the first (only) sequence.
        next_token_logits = outputs[0][0, -1, :]
        next_token = jnp.argmax(next_token_logits)
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids
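
To make the greedy loop easier to follow in isolation, here is a minimal plain-Python sketch of the same idea. The ``toy_logits`` function is a hypothetical stand-in for a language model over a five-token vocabulary; it is not GPT-2 and only serves to show the control flow.

```python
def toy_logits(token_ids):
    """Hypothetical stand-in for a model: scores over a 5-token vocabulary."""
    last = token_ids[-1]
    # Favour the token (last + 1) mod 5; every other token scores lower.
    return [1.0 if tok == (last + 1) % 5 else 0.0 for tok in range(5)]


def greedy_generate(token_ids, steps):
    token_ids = list(token_ids)
    for _ in range(steps):
        logits = toy_logits(token_ids)
        # Greedy search: pick the single highest-scoring token.
        next_token = max(range(len(logits)), key=logits.__getitem__)
        token_ids.append(next_token)
    return token_ids


print(greedy_generate([3], steps=4))
```

Like ``text_generation``, this loop re-evaluates the model on the whole sequence at every step and always takes the argmax, so the output is deterministic for a given prompt.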

Generate Text on the CPU#

[4]:
import jax.numpy as jnp

inputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
outputs_ids = text_generation(inputs_ids, pretrained_model.params)

print('-' * 65 + '\nRun on CPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)
2023-06-15 17:07:55.627043: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
2023-06-15 17:07:55.627112: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
2023-06-15 17:07:55.627118: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
-----------------------------------------------------------------
Run on CPU:
-----------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever
-----------------------------------------------------------------

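Here we generated 10 tokens. Keep the generated text in mind, because we will generate text on the SPU next.

Generate Text on the SPU#

Import the Required Libraries and Configure Optimizations#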

[ ]:
import secretflow as sf
from typing import Any, Callable, Dict, Optional, Tuple, Union
import jax.nn as jnn
import flax.linen as nn
from flax.linen.linear import Array
import jax
import argparse
import spu.utils.distributed as ppd
import spu.intrinsic as intrinsic
import spu.spu_pb2 as spu_pb2
from contextlib import contextmanager

copts = spu_pb2.CompilerOptions()
copts.enable_pretty_print = False
copts.xla_pp_kind = 2
# enable x / broadcast(y) -> x * broadcast(1/y)
copts.enable_optimize_denominator_with_broadcast = True
Array = Any

# In case you have a running secretflow runtime already.
sf.shutdown()

Hijack Softmax and Define Its Optimized Replacement#

[ ]:
def hack_softmax(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
    x = x - x_max

    # exp on large negative is clipped to zero
    b = x > -14
    nexp = jnp.exp(x)

    divisor = jnp.sum(nexp, axis, where=where, keepdims=True)

    return b * (nexp / divisor)


@contextmanager
def hack_softmax_context(msg: str, enabled: bool = False):
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_softmax = jnn.softmax
    jnn.softmax = hack_softmax
    try:
        yield
    finally:
        # restore the original function, even if the body raised
        jnn.softmax = raw_softmax
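
As a rough plaintext check of what the clipping does, the sketch below re-implements the same steps for a one-dimensional list in plain Python and compares them with an ordinary softmax. It does not involve SPU, and the helper names ``exact_softmax`` and ``clipped_softmax`` are ours. Since exp(-14) is roughly 8.3e-7, the entries that are zeroed out were already negligible.

```python
import math


def exact_softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def clipped_softmax(xs, threshold=-14.0):
    # Mirrors hack_softmax for a 1-D list: shift by the max, exponentiate,
    # normalise, then zero every entry whose shifted input is <= threshold.
    m = max(xs)
    shifted = [x - m for x in xs]
    exps = [math.exp(s) for s in shifted]
    total = sum(exps)
    return [(s > threshold) * (e / total) for s, e in zip(shifted, exps)]


logits = [2.0, 1.0, -20.0, 0.5]
exact = exact_softmax(logits)
clipped = clipped_softmax(logits)
max_diff = max(abs(a - b) for a, b in zip(exact, clipped))
print(clipped)
print(max_diff)
```

The third logit lies 22 below the maximum, so its probability is set to exactly zero while the remaining entries are unchanged.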

Hijack GeLU and Define Its Optimized Replacement#

[ ]:
def hack_gelu(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    b0 = x < -4.0
    b1 = x < -1.95
    b2 = x > 3.0
    b3 = b1 ^ b2 ^ True  # x in [-1.95, 3.0]
    b4 = b0 ^ b1  # x in [-4, -1.95]

    # seg1 = a[3] * x^3 + a[2] * x^2 + a[1] * x + a[0]
    # seg2 = b[6] * x^6 + b[4] * x^4 + b[2] * x^2 + b[1] * x + b[0]
    a_coeffs = jnp.array(
        [
            -0.5054031199708174,
            -0.42226581151983866,
            -0.11807612951181953,
            -0.011034134030615728,
        ]
    )
    b_coeffs = jnp.array(
        [
            0.008526321541038084,
            0.5,
            0.3603292692789629,
            0.0,
            -0.037688200365904236,
            0.0,
            0.0018067462606141187,
        ]
    )
    x2 = jnp.square(x)
    x3 = jnp.multiply(x, x2)
    x4 = jnp.square(x2)
    x6 = jnp.square(x3)

    seg1 = a_coeffs[3] * x3 + a_coeffs[2] * x2 + a_coeffs[1] * x + a_coeffs[0]
    seg2 = (
        b_coeffs[6] * x6
        + b_coeffs[4] * x4
        + b_coeffs[2] * x2
        + b_coeffs[1] * x
        + b_coeffs[0]
    )

    ret = b2 * x + b4 * seg1 + b3 * seg2

    return ret


@contextmanager
def hack_gelu_context(msg: str, enabled: bool = False):
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_gelu = jnn.gelu
    jnn.gelu = hack_gelu
    try:
        yield
    finally:
        # restore the original function, even if the body raised
        jnn.gelu = raw_gelu
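
To get a feel for how closely the piecewise polynomial tracks GeLU, the following plain-Python sketch evaluates the same segments and coefficients as ``hack_gelu`` on a grid and compares them with the erf-based definition of GeLU. The grid range and the helper names are our own choices for illustration, and the comparison runs in plaintext floating point, not under SPU's fixed-point arithmetic.

```python
import math

A = [-0.5054031199708174, -0.42226581151983866,
     -0.11807612951181953, -0.011034134030615728]
B = [0.008526321541038084, 0.5, 0.3603292692789629, 0.0,
     -0.037688200365904236, 0.0, 0.0018067462606141187]


def exact_gelu(x):
    # Erf-based definition of GeLU.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def piecewise_gelu(x):
    # Same segments as hack_gelu, written with ordinary branches.
    if x < -4.0:
        return 0.0
    if x < -1.95:
        return A[3] * x**3 + A[2] * x**2 + A[1] * x + A[0]
    if x <= 3.0:
        return B[6] * x**6 + B[4] * x**4 + B[2] * x**2 + B[1] * x + B[0]
    return x


# Evaluate on [-6, 6] in steps of 0.01 and record the largest deviation.
grid = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(piecewise_gelu(x) - exact_gelu(x)) for x in grid)
print(max_err)
```

``hack_gelu`` expresses the same branches as arithmetic on comparison results (``b2 * x + b4 * seg1 + b3 * seg2``), which avoids data-dependent control flow on secret values.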

Launch Puma for the GPT-2 Model#

[5]:
sf.init(['alice', 'bob', 'carol'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['protocol'] = 'ABY3'
conf['runtime_config']['field'] = 'FM64'
conf['runtime_config']['fxp_exp_mode'] = 0
conf['runtime_config']['fxp_exp_iters'] = 5

spu = sf.SPU(conf)


def get_model_params():
    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
    return pretrained_model.params


def get_token_ids():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')


model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()

device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)

with hack_softmax_context("hijack jax softmax", enabled=True), hack_gelu_context(
    "hack jax gelu", enabled=True
):
    output_token_ids = spu(text_generation, copts=copts)(
        input_token_ids_, model_params_
    )
WARNING:root:Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
2023-06-15 17:08:14,157 INFO worker.py:1538 -- Started a local Ray instance.
(pid=2109508) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(pid=2109408) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(pid=2121303) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(pid=2121304) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(pid=2121301) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(_run pid=2109408) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2109408) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=2109408) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2109408) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2109508) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2109508) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=2109508) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2109508) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2109408) [2023-06-15 17:08:24.221] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127
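
The runtime configuration above selects the ABY3 protocol over a 64-bit ring (``FM64``). As background, the plain-Python sketch below illustrates the general idea of three-party replicated secret sharing combined with fixed-point encoding. It is not SPU's implementation: all helper names and the choice of 18 fractional bits are assumptions made purely for illustration.

```python
import secrets

RING_BITS = 64
MODULUS = 1 << RING_BITS
FRAC_BITS = 18  # assumed fixed-point precision, for illustration only


def encode(value):
    # Map a real number to a ring element using fixed-point encoding.
    return round(value * (1 << FRAC_BITS)) % MODULUS


def decode(element):
    # Interpret the upper half of the ring as negative numbers.
    if element >= MODULUS // 2:
        element -= MODULUS
    return element / (1 << FRAC_BITS)


def share(element):
    # Split into three additive shares; party i holds shares i and i+1.
    s0 = secrets.randbelow(MODULUS)
    s1 = secrets.randbelow(MODULUS)
    s2 = (element - s0 - s1) % MODULUS
    shares = [s0, s1, s2]
    return [(shares[i], shares[(i + 1) % 3]) for i in range(3)]


def reconstruct(party_shares):
    # The first component of each party's pair covers all three shares.
    return sum(pair[0] for pair in party_shares) % MODULUS


def add_shared(a, b):
    # Addition needs no communication: each party adds its own shares.
    return [((a[i][0] + b[i][0]) % MODULUS, (a[i][1] + b[i][1]) % MODULUS)
            for i in range(3)]


x, y = share(encode(1.5)), share(encode(-0.25))
result = decode(reconstruct(add_shared(x, y)))
print(result)
```

Each individual share is a uniformly random ring element, so a single party learns nothing about the model parameters or the input tokens from its own shares.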

Check the Puma Output#

As you can see, running GPT-2 inference on Puma is straightforward. Next, let's reveal the text generated on the SPU in plaintext.

[6]:
outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)
(_spu_compile pid=2109408) 2023-06-15 17:09:12.722333: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(_spu_compile pid=2109408) 2023-06-15 17:09:12.722414: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(_spu_compile pid=2109408) 2023-06-15 17:09:12.722421: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
(SPURuntime(device_id=None, party=bob) pid=2121303) 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime(device_id=None, party=alice) pid=2121301) 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime(device_id=None, party=carol) pid=2121304) 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
-----------------------------------------------------------------
Run on SPU:
-----------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever
-----------------------------------------------------------------

As you can see, the text generated by Puma is identical to the text generated on the CPU!

This concludes the tutorial. For more Puma tests, see https://github.com/secretflow/spu/tree/main/examples/python/ml/flax_llama7b