I tried running GPT-J in a Windows 11 GPU environment

Hello all. This is a note on how I ran EleutherAI’s GPT-J, a GPT-3-style language model, in my own environment using mesh-transformer-jax.


This time I’m running it in Docker Desktop on Windows 11, set up as described in the following article. The GPU is an NVIDIA RTX 3090.


Setting up an environment with a TensorFlow container

This time we will use NVIDIA’s TensorFlow container. Create a working folder, D:\work\gpt-j, into which mesh-transformer-jax will be cloned.
When starting the container, forward port 8888 so that Jupyter Notebook can be reached from the host, and mount the working folder at /gpt-j.

docker run --gpus all -it -p 8888:8888 -v D:\work\gpt-j:/gpt-j nvcr.io/nvidia/tensorflow:21.12-tf2-py3 bash
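If you would rather start the container from a script on the Windows host, the same command can be assembled as an argument list. This is a minimal sketch: the helper name docker_run_args and its defaults are hypothetical, copied from the command above.

```python
from typing import List


def docker_run_args(host_dir: str,
                    container_dir: str = "/gpt-j",
                    port: int = 8888,
                    image: str = "nvcr.io/nvidia/tensorflow:21.12-tf2-py3") -> List[str]:
    """Build the `docker run` command above as a list of arguments (hypothetical helper)."""
    return [
        "docker", "run",
        "--gpus", "all",                      # give the container access to every NVIDIA GPU
        "-it",                                # interactive session with a terminal
        "-p", f"{port}:{port}",               # forward host port -> container port (Jupyter)
        "-v", f"{host_dir}:{container_dir}",  # mount the Windows work folder into the container
        image,
        "bash",                               # run a shell as the container's command
    ]
```

To launch it, pass the list to subprocess.run, e.g. subprocess.run(docker_run_args(r"D:\work\gpt-j"), check=True). Passing a list instead of one string avoids quoting problems with Windows paths.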

Once the container has started, clone mesh-transformer-jax.

cd /gpt-j
git clone https://github.com/kingoflolz/mesh-transformer-jax.git
cd mesh-transformer-jax

The requirements.txt in this repository lists tensorflow-cpu, the CPU-only TensorFlow package, so rewrite it to use the regular tensorflow package instead. Also install a jaxlib build with CUDA support.

sed -e 's/tensorflow-cpu/tensorflow/' requirements.txt > new_requirements.txt
pip install -r new_requirements.txt
pip install jax==0.2.12
pip install jaxlib==0.1.68+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
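For reference, here is a Python version of the requirements.txt rewrite that only changes the package name, so a pin such as tensorflow-cpu~=2.6.0 keeps its version specifier. It also includes a quick check that the CUDA build of jaxlib actually sees the GPU. This is a sketch: both function names are hypothetical, and the version pin in your copy of requirements.txt may differ from the example.

```python
import re
from typing import List, Optional


def use_gpu_tensorflow(requirements: str) -> str:
    """Replace the tensorflow-cpu requirement with tensorflow, keeping any version pin."""
    # (?m) makes ^ match at every line start; the lookahead requires the package
    # name to end there (a version operator, end of line, etc.), so names such as
    # "tensorflow-cpu-extra" are left alone.
    return re.sub(r"(?m)^tensorflow-cpu(?=[^\w-]|$)", "tensorflow", requirements)


def jax_platforms() -> Optional[List[str]]:
    """Return the platform of each device JAX can see, or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return [d.platform for d in jax.devices()]
```

Inside the container, use_gpu_tensorflow(open("requirements.txt").read()) produces the same result as the sed command. After installing jaxlib, jax_platforms() should include "gpu"; if it only lists "cpu", the CPU-only jaxlib is still the one being imported.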

Next, download the slim GPT-J-6B checkpoint (bf16 weights only, intended for inference) and extract it.

wget -c https://mystic.the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
tar -I zstd -xf step_383500_slim.tar.zstd
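The -c flag makes wget continue a partially downloaded file, which helps with a multi-gigabyte checkpoint. This works through an HTTP Range request that starts at the current file size. Below is a minimal stdlib sketch of the same idea. It assumes the server supports range requests, and the function names are hypothetical.

```python
import os
import shutil
import urllib.error
import urllib.request


def resume_request(url: str, dest: str) -> urllib.request.Request:
    """Build a GET request that asks only for the bytes not yet present in dest."""
    offset = os.path.getsize(dest) if os.path.exists(dest) else 0
    headers = {"Range": f"bytes={offset}-"} if offset else {}
    return urllib.request.Request(url, headers=headers)


def download_resumable(url: str, dest: str) -> None:
    """Download url to dest, continuing a partial file like `wget -c` (sketch)."""
    try:
        resp = urllib.request.urlopen(resume_request(url, dest))
    except urllib.error.HTTPError as err:
        if err.code == 416:  # Range Not Satisfiable: the file is most likely already complete
            return
        raise
    with resp:
        # 206 Partial Content: the server honoured the Range header, so append.
        # 200 OK: the server sent the whole file, so start again from scratch.
        mode = "ab" if resp.status == 206 else "wb"
        with open(dest, mode) as f:
            shutil.copyfileobj(resp, f)
```

Checking the status code matters: a server that ignores Range replies with 200 and the full body, and appending that to the partial file would corrupt it.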

Checking that GPT-J works

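Run resharding_example.py to check that GPT-J works. The script’s last line calls infer("EleutherAI is") without printing the result, so first make a copy that wraps the call in print().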

sed -e 's/infer("EleutherAI is")/print(infer("EleutherAI is"))/' resharding_example.py > resharding_example2.py
python resharding_example2.py

Execution results (excerpt)

completion done in 98.40319633483887s
[' a single player, strategy game ....

The first inference takes a while (about 98 seconds here), but if the output continues the prompt “EleutherAI is” with generated text, as above, GPT-J is working.
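To see how much of the first call is one-off start-up cost, you can time the same call twice. Part of the difference is likely JAX compiling the computation on the first call, since compiled code is reused for later calls with the same shapes. Here is a small stdlib helper; the name timed is hypothetical, and infer is the function defined in resharding_example.py.

```python
import time
from typing import Any, Callable, Tuple


def timed(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Call fn(*args, **kwargs) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start
```

For example, run _, first = timed(infer, "EleutherAI is") and then repeat the identical call. Comparing the two durations separates start-up cost from steady-state generation time.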

Trying other prompts

For trying other prompts, Jupyter Notebook is more convenient. Replace the infer call on the last line of resharding_example2.py created above with the following.

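# Replaces the final infer(...) call in resharding_example2.py; infer() is defined earlier in that script.
top_p = 0.9  # nucleus (top-p) sampling threshold
temp = 1     # sampling temperature

# Few-shot prompt in Japanese. In English:
#   "I am a smart question-answering bot that answers truthfully.
#    Q: What is the population of Japan?
#    A: 120 million.
#    Q: Which country has the largest population in the world?
#    A: "
context = '''私は真実を答える賢い質問応答ボットです。 
Q: 日本の人口は?
A: 1.2億人です。
Q: 世界で一番人口が多い国は?
A: '''

# Generate 64 tokens and print the first sample of the batch.
print(infer(top_p=top_p, temp=temp, gen_len=64, context=context)[0])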

Execution results (excerpt)

Q: 日本の人口は?
A: 1.2億人です。
Q: 世界で一番人口が多い国は?
completion done in 9.850934267044067s
Q: 欧州で人口が多い国は?
A: 英国です。
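Q: 経済力が優れている国は?

(In English: the prompt reads “Q: What is the population of Japan? A: 120 million. Q: Which country has the largest population in the world?” and the generated text reads “Q: Which countries in Europe have large populations? A: The United Kingdom. Q: Which countries have strong economies?”)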
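The top_p and temp arguments control how each next token is sampled. Temperature rescales the model’s scores before they are turned into probabilities, and a lower value gives more conservative choices. Top-p (nucleus) sampling then keeps only the smallest set of most likely tokens whose probabilities add up to top_p, and draws from that set. The sketch below is a generic stdlib illustration over a toy list of scores. It is not the mesh-transformer-jax code, which works on JAX arrays and may apply these steps in a different order.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float], temp: float = 1.0) -> List[float]:
    """Turn raw scores into probabilities; a lower temp sharpens the distribution."""
    if temp <= 0:
        raise ValueError("temp must be positive")
    scaled = [x / temp for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_indices(probs: List[float], top_p: float) -> List[int]:
    """Smallest set of most likely token indices whose cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return kept


def sample_next_token(logits: List[float], top_p: float = 0.9, temp: float = 1.0,
                      rng: Optional[random.Random] = None) -> int:
    """Temperature scaling, then nucleus filtering, then a weighted random draw."""
    rng = rng or random.Random()
    probs = softmax(logits, temp)
    kept = top_p_indices(probs, top_p)
    # random.choices normalises the weights, so the kept probabilities need not sum to 1.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

With the toy scores [2.0, 1.0, 0.1] and top_p=0.9, the two most likely tokens (about 0.66 and 0.24) are kept and the third is never sampled. With top_p=0.5, only the top token remains.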

Conclusion

It’s fun to run GPT-J in your own environment. However, memory is only barely sufficient even on the RTX 3090 with its 24 GB, so a GPU with more memory may be a better choice.
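A rough way to see why 24 GB is tight: GPT-J-6B has about 6.05 billion parameters (6,053,381,344 in its published model card), and each one takes 2 bytes in bfloat16 or 4 bytes in float32. The arithmetic below covers the weights only. Activations, the generation cache and framework overhead come on top. The helper name is hypothetical.

```python
GPTJ_6B_PARAMS = 6_053_381_344  # parameter count published for GPT-J-6B

BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_gb(n_params: int, dtype: str) -> float:
    """Memory needed just to hold the weights, in gigabytes (10**9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in ("bfloat16", "float32"):
    print(f"{dtype}: {weight_gb(GPTJ_6B_PARAMS, dtype):.1f} GB")
```

This prints 12.1 GB for bfloat16 and 24.2 GB for float32. Compare those figures with the 24 GB on the RTX 3090, keeping in mind that the weights are not the only thing that must fit.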

This time I tested it on an HP gaming PC as an experiment, but if you want to run it on Azure or a similar cloud, one of the following VMs may be a better fit.

  • Standard_ND40rs_v2: NVIDIA V100, 32 GB per GPU
  • Standard_ND96asr_v4: NVIDIA A100, 40 GB per GPU
  • Standard_ND96amsr_A100_v4: NVIDIA A100, 80 GB per GPU

Azure OpenAI Service is still in preview, so I hope it becomes generally available soon.