GPT-2 Secure Inference with Puma
In this lab, we show how to run three-party (3PC) secure inference with Puma on a pre-trained GPT-2 model for text generation.
First, we show how to use JAX and the Hugging Face Transformers library to generate text with the pre-trained GPT-2 model. Then, we show how to use Puma on top of SPU for secure text generation, with only minor modifications to the plaintext code.
The following code is for demonstration only. It is NOT intended for production because of system security concerns; please DO NOT use it directly in production.
This tutorial may need more resources than 16 CPU cores and 48 GB of memory (16c48g).
What is Puma?
Puma is a fast and accurate framework for end-to-end three-party secure inference of Transformer models. Puma designs high-quality approximations for expensive functions, such as \(\mathsf{GeLU}\) and \(\mathsf{Softmax}\), which significantly reduce the cost of secure inference while preserving model performance. Additionally, we design secure \(\mathsf{Embedding}\) and \(\mathsf{LayerNorm}\) procedures that faithfully implement the desired functionality without undermining the Transformer architecture. Puma is approximately \(2\times\) faster than the state-of-the-art MPC framework MPCFormer (ICLR 2023) and achieves accuracy similar to plaintext models without fine-tuning, which previous works failed to achieve.
Text generation using GPT-2 with JAX/Flax
Install the transformers library
[ ]:
import sys
!{sys.executable} -m pip install transformers[flax]
The JAX version required by transformers may not match the one required by SPU. In this example, it is fine to run SPU with the conflicting JAX version.
Load the pre-trained GPT-2 Model
Please refer to this documentation for more details.
[2]:
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
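To hijack the GeLU function of GPT-2, replace the call to self.act with jax.nn.gelu so that the hijacked function defined below is picked up. For example, in transformers/src/transformers/models/gpt2/modeling_flax_gpt2.py, line 296 (the line number may differ across transformers versions):
hidden_states = self.act(hidden_states)
is changed to
hidden_states = jax.nn.gelu(hidden_states)
By default, jax.nn.gelu uses the tanh approximation, which is also GPT-2's default gelu_new activation. This change therefore does not alter the plaintext result.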
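Why edit the library source instead of only patching jnn.gelu at runtime? A runtime patch replaces the gelu attribute on the jax.nn module, so it only affects code that looks up jax.nn.gelu at call time. As far as we can tell, the Flax GPT-2 block stores its activation function in self.act when the module is set up, so a later patch would not reach it. The plain-Python sketch below illustrates this binding behavior with a hypothetical stand-in module, fake_nn. It does not import transformers or JAX.

```python
import types
from functools import partial

# Hypothetical stand-in for a module such as jax.nn (illustration only).
fake_nn = types.SimpleNamespace(gelu=lambda x: "original gelu")

# Captured once, the way an activation stored in `self.act` would be.
captured_act = partial(fake_nn.gelu)


def call_through_module(x):
    # Looks the attribute up on every call, like `jax.nn.gelu(hidden_states)`.
    return fake_nn.gelu(x)


# Monkeypatch the module attribute, as hack_gelu_context does for jnn.gelu.
fake_nn.gelu = lambda x: "hacked gelu"

print(captured_act(0.0))         # original gelu  (patch not seen)
print(call_through_module(0.0))  # hacked gelu    (patch seen)
```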
Define the text generation function
We use a greedy search strategy for text generation here.
[3]:
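import jax.numpy as jnp
from transformers import FlaxGPT2LMHeadModel, GPT2Config


def text_generation(input_ids, params):
    config = GPT2Config()
    model = FlaxGPT2LMHeadModel(config=config)
    # Greedy search: generate 10 new tokens, one at a time.
    for _ in range(10):
        outputs = model(input_ids=input_ids, params=params)
        # Logits for the last position of the (single) input sequence.
        next_token_logits = outputs[0][0, -1, :]
        next_token = jnp.argmax(next_token_logits)
        # Append the chosen token and feed the longer sequence back in.
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids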
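The loop above is plain greedy decoding: at each step, score the next token, take the argmax, and append it to the sequence. Here is a dependency-free sketch of the same loop. It replaces GPT-2 with a hypothetical toy scoring function, toy_logits, so the control flow is easy to follow. Like text_generation, it re-scores the whole sequence at every step and uses no key/value cache.

```python
from typing import Callable, List, Sequence


def greedy_decode(
    prompt: Sequence[int],
    logits_fn: Callable[[List[int]], List[float]],
    num_new_tokens: int = 10,
) -> List[int]:
    """Append the highest-scoring next token `num_new_tokens` times."""
    ids = list(prompt)
    for _ in range(num_new_tokens):
        logits = logits_fn(ids)  # scores for the next token only
        next_token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        ids.append(next_token)
    return ids


def toy_logits(ids: List[int]) -> List[float]:
    """Hypothetical 5-token "model" that always favours (last_token + 1) % 5."""
    vocab_size = 5
    favourite = (ids[-1] + 1) % vocab_size
    return [1.0 if t == favourite else 0.0 for t in range(vocab_size)]


print(greedy_decode([0], toy_logits, num_new_tokens=6))  # [0, 1, 2, 3, 4, 0, 1]
```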
Run text generation on CPU
[4]:
import jax.numpy as jnp
inputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
outputs_ids = text_generation(inputs_ids, pretrained_model.params)
print('-' * 65 + '\nRun on CPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)
2023-06-15 17:07:55.627043: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
2023-06-15 17:07:55.627112: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
2023-06-15 17:07:55.627118: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
-----------------------------------------------------------------
Run on CPU:
-----------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever
-----------------------------------------------------------------
Here we generate 10 tokens. Keep the generated text in mind; in the next step, we generate text on SPU and compare the two results.
Run text generation on SPU
Import the necessary libraries and configure the compiler optimizations
[ ]:
import argparse
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Tuple, Union

import flax.linen as nn
import jax
import jax.nn as jnn
import jax.numpy as jnp
from flax.linen.linear import Array

import secretflow as sf
import spu.intrinsic as intrinsic
import spu.spu_pb2 as spu_pb2
import spu.utils.distributed as ppd

# Compiler options passed to the SPU device.
copts = spu_pb2.CompilerOptions()
copts.enable_pretty_print = False
# Pretty-print output kind; only relevant when enable_pretty_print is True.
copts.xla_pp_kind = 2
# enable x / broadcast(y) -> x * broadcast(1/y)
copts.enable_optimize_denominator_with_broadcast = True

# Generic alias used in the type hints below.
Array = Any

# In case you have a running secretflow runtime already.
sf.shutdown()
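The enable_optimize_denominator_with_broadcast option is described by its comment as rewriting x / broadcast(y) into x * broadcast(1/y). In MPC, a division typically needs an iterative reciprocal approximation and costs far more than a multiplication. Computing one reciprocal per row and then multiplying therefore saves work. Softmax is a typical case: every element in a row is divided by the same row sum. The plaintext sketch below only counts divisions to illustrate the arithmetic. It is an assumption-level illustration, not how SPU implements the pass.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def divide_broadcast(x: Matrix, y: List[float]) -> Tuple[Matrix, int]:
    """x / broadcast(y): one division per element."""
    out = [[v / d for v in row] for row, d in zip(x, y)]
    return out, sum(len(row) for row in x)


def multiply_by_reciprocal(x: Matrix, y: List[float]) -> Tuple[Matrix, int]:
    """x * broadcast(1 / y): one division per row, then only multiplications."""
    reciprocals = [1.0 / d for d in y]
    out = [[v * r for v in row] for row, r in zip(x, reciprocals)]
    return out, len(reciprocals)


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
y = [2.0, 4.0]  # e.g. one softmax denominator per row
slow, slow_divs = divide_broadcast(x, y)
fast, fast_divs = multiply_by_reciprocal(x, y)
print(slow_divs, fast_divs)  # 6 2
```

The denominators here are powers of two, so both versions give bit-identical results. In general they agree only up to floating-point (or fixed-point) rounding.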
Define the Softmax hijack function
[ ]:
def hack_softmax(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    # Subtract the maximum for numerical stability.
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
    x = x - x_max
    # exp on large negative is clipped to zero
    b = x > -14
    nexp = jnp.exp(x)
    divisor = jnp.sum(nexp, axis, where=where, keepdims=True)
    return b * (nexp / divisor)


@contextmanager
def hack_softmax_context(msg: str, enabled: bool = False):
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_softmax = jnn.softmax
    jnn.softmax = hack_softmax
    try:
        yield
    finally:
        # recover back, even if the body raised an exception
        jnn.softmax = raw_softmax
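To see what the clipping in hack_softmax does numerically, here is a plain-Python port of it next to a standard softmax. Entries whose max-shifted value is at or below -14 are set to zero. Since \(e^{-14} \approx 8.3 \times 10^{-7}\), this barely changes the plaintext result. Our reading of the Puma design is that the benefit is on the MPC side, where fixed-point exp approximations can be inaccurate for large negative inputs. Treat that motivation as our interpretation. The function names below are illustrative only.

```python
import math
from typing import List


def softmax(x: List[float]) -> List[float]:
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]


def clipped_softmax(x: List[float], threshold: float = -14.0) -> List[float]:
    """Mirrors hack_softmax: shift by the max, then zero entries <= threshold."""
    m = max(x)
    shifted = [v - m for v in x]
    e = [math.exp(v) for v in shifted]
    s = sum(e)  # the denominator still includes the clipped entries
    return [(v > threshold) * (ev / s) for v, ev in zip(shifted, e)]


logits = [0.0, -1.0, -20.0]
print(softmax(logits))
print(clipped_softmax(logits))  # last entry becomes exactly 0.0
```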
Define the GeLU hijack function
[ ]:
def hack_gelu(x: Array, approximate: bool = True) -> Array:
    # `approximate` only mirrors the signature of jax.nn.gelu; it is ignored.
    b0 = x < -4.0
    b1 = x < -1.95
    b2 = x > 3.0
    b3 = b1 ^ b2 ^ True  # x in [-1.95, 3.0]
    b4 = b0 ^ b1  # x in [-4, -1.95)
    # For x < -4 no segment is selected, so the result there is 0.
    # seg1 = a[3] * x^3 + a[2] * x^2 + a[1] * x + a[0]
    # seg2 = b[6] * x^6 + b[4] * x^4 + b[2] * x^2 + b[1] * x + b[0]
    a_coeffs = jnp.array(
        [
            -0.5054031199708174,
            -0.42226581151983866,
            -0.11807612951181953,
            -0.011034134030615728,
        ]
    )
    b_coeffs = jnp.array(
        [
            0.008526321541038084,
            0.5,
            0.3603292692789629,
            0.0,
            -0.037688200365904236,
            0.0,
            0.0018067462606141187,
        ]
    )
    x2 = jnp.square(x)
    x3 = jnp.multiply(x, x2)
    x4 = jnp.square(x2)
    x6 = jnp.square(x3)
    seg1 = a_coeffs[3] * x3 + a_coeffs[2] * x2 + a_coeffs[1] * x + a_coeffs[0]
    seg2 = (
        b_coeffs[6] * x6
        + b_coeffs[4] * x4
        + b_coeffs[2] * x2
        + b_coeffs[1] * x
        + b_coeffs[0]
    )
    # x > 3 -> identity, [-4, -1.95) -> seg1, [-1.95, 3] -> seg2.
    ret = b2 * x + b4 * seg1 + b3 * seg2
    return ret


@contextmanager
def hack_gelu_context(msg: str, enabled: bool = False):
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_gelu = jnn.gelu
    jnn.gelu = hack_gelu
    try:
        yield
    finally:
        # recover back, even if the body raised an exception
        jnn.gelu = raw_gelu
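The masks in hack_gelu define a piecewise function with four pieces:
- 0 for \(x < -4\)
- the cubic seg1 on \([-4, -1.95)\)
- the degree-6 polynomial seg2 on \([-1.95, 3]\)
- \(x\) itself for \(x > 3\)

The plain-Python sketch below rewrites it with explicit branches, reusing the coefficients from the cell above. It then measures the error against the exact \(\mathsf{GeLU}(x) = \tfrac{1}{2}x\,(1 + \operatorname{erf}(x/\sqrt{2}))\) on a grid. The branch form is only for readability; the masked sum in hack_gelu avoids data-dependent branches, which secret inputs cannot drive.

```python
import math

A = [-0.5054031199708174, -0.42226581151983866,
     -0.11807612951181953, -0.011034134030615728]
B = [0.008526321541038084, 0.5, 0.3603292692789629, 0.0,
     -0.037688200365904236, 0.0, 0.0018067462606141187]


def poly(coeffs, x: float) -> float:
    """Evaluate sum_i coeffs[i] * x**i."""
    return sum(c * x**i for i, c in enumerate(coeffs))


def piecewise_gelu(x: float) -> float:
    if x < -4.0:
        return 0.0           # no mask selected in hack_gelu
    if x < -1.95:
        return poly(A, x)    # seg1
    if x <= 3.0:
        return poly(B, x)    # seg2
    return x                 # b2: identity


def exact_gelu(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


grid = [i / 100 for i in range(-600, 601)]
max_err = max(abs(piecewise_gelu(x) - exact_gelu(x)) for x in grid)
print(f"max |error| on [-6, 6]: {max_err:.4f}")
```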
Launch Puma on GPT-2
[5]:
# Start a local SecretFlow runtime with three parties.
sf.init(['alice', 'bob', 'carol'], address='local')
alice, bob = sf.PYU('alice'), sf.PYU('bob')

# Configure the SPU device: ABY3 (3PC) protocol over a 64-bit ring.
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['protocol'] = 'ABY3'
conf['runtime_config']['field'] = 'FM64'
# Settings of the fixed-point exponential approximation.
conf['runtime_config']['fxp_exp_mode'] = 0
conf['runtime_config']['fxp_exp_iters'] = 5
spu = sf.SPU(conf)


def get_model_params():
    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
    return pretrained_model.params


def get_token_ids():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')


# Alice owns the model parameters; Bob owns the input prompt.
model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()

# Move both private inputs to the SPU device.
device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)

# Run greedy text generation on SPU with the Softmax and GeLU hijacks enabled.
with hack_softmax_context("hijack jax softmax", enabled=True), hack_gelu_context(
    "hack jax gelu", enabled=True
):
    output_token_ids = spu(text_generation, copts=copts)(
        input_token_ids_, model_params_
    )
WARNING:root:Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
2023-06-15 17:08:14,157 INFO worker.py:1538 -- Started a local Ray instance.
(pid=2109508) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(pid=2109408) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(pid=2121303) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(pid=2121304) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(pid=2121301) Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.
(_run pid=2109408) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2109408) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=2109408) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2109408) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2109508) INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
(_run pid=2109508) INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
(_run pid=2109508) INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
(_run pid=2109508) WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
(_run pid=2109408) [2023-06-15 17:08:24.221] [info] [thread_pool.cc:30] Create a fixed thread pool with size 127
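The launch cell selects the ABY3 protocol over the FM64 field. As background, ABY3 builds on 2-out-of-3 replicated secret sharing. Each value is split into three additive shares modulo \(2^{64}\), and each party holds two of the three shares, so no single party learns the value. The toy sketch below illustrates only this arithmetic in plain Python. It is insecure and is not SPU's implementation. FRAC_BITS = 18 is an arbitrary illustrative fixed-point precision, not necessarily SPU's default.

```python
import secrets
from typing import List, Tuple

RING_BITS = 64          # FM64: all arithmetic is modulo 2**64
MOD = 1 << RING_BITS
FRAC_BITS = 18          # illustrative number of fixed-point fractional bits


def encode(x: float) -> int:
    """Map a real number to a fixed-point ring element."""
    return round(x * (1 << FRAC_BITS)) % MOD


def decode(v: int) -> float:
    """Map a ring element back to a real number (upper half = negative)."""
    if v >= MOD // 2:
        v -= MOD
    return v / (1 << FRAC_BITS)


def share(v: int) -> List[int]:
    """Split v into three random additive shares that sum to v mod 2**64."""
    s0, s1 = secrets.randbelow(MOD), secrets.randbelow(MOD)
    return [s0, s1, (v - s0 - s1) % MOD]


def replicate(shares: List[int]) -> List[Tuple[int, int]]:
    """Party i holds (s_i, s_{i+1}), so any two parties hold all three shares."""
    return [(shares[i], shares[(i + 1) % 3]) for i in range(3)]


def reconstruct(shares: List[int]) -> int:
    return sum(shares) % MOD


# Addition is local: each share is added independently, without communication.
x_shares = share(encode(1.5))
y_shares = share(encode(-0.25))
z_shares = [(a + b) % MOD for a, b in zip(x_shares, y_shares)]
print(decode(reconstruct(z_shares)))  # 1.25

# Parties 0 and 1 together can recover all three shares of z.
p0, p1, _ = replicate(z_shares)
print(decode((p0[0] + p0[1] + p1[1]) % MOD))  # 1.25
```

Unlike addition, multiplying shared values requires communication between parties. This is why nonlinear functions such as \(\mathsf{GeLU}\) and \(\mathsf{Softmax}\), built from many multiplications, dominate the cost of secure inference.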
Check the Puma output
As you can see, running GPT-2 inference on Puma requires only minor changes to the plaintext code. Now let's reveal the text generated by Puma.
[6]:
outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)
(_spu_compile pid=2109408) 2023-06-15 17:09:12.722333: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(_spu_compile pid=2109408) 2023-06-15 17:09:12.722414: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst
(_spu_compile pid=2109408) 2023-06-15 17:09:12.722421: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
(SPURuntime(device_id=None, party=bob) pid=2121303) 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime(device_id=None, party=alice) pid=2121301) 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
(SPURuntime(device_id=None, party=carol) pid=2121304) 2023-06-15 17:09:32.011 [info] [thread_pool.cc:ThreadPool:30] Create a fixed thread pool with size 127
-----------------------------------------------------------------
Run on SPU:
-----------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever
-----------------------------------------------------------------
As we can see, the text generated by Puma is exactly the same as the text generated on CPU!
This is the end of the lab. For more Puma benchmarks, please refer to the secretflow/spu repository.