
How to use Jina embeddings in Elixir with Bumblebee

The Jina embedding models stand out. Don't let a missing model implementation stop you from realizing your awesome AI project in Elixir.

Joel Koch

Introduction

It's 2024 and every meaningful project is using some form of RAG-based search. So, you want to start your next project with awesome AI capabilities. Of course, you want to use Elixir to get a robust, scalable and maintainable product.

You've heard good things about the embeddings models from Jina AI and want to integrate them into your Elixir project.

You head over to the getting started guides of Bumblebee, adapt one of them to use the Jina embeddings v2 model and let's go...

repo = {:hf, "jinaai/jina-embeddings-v2-base-en"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

serving =
  Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
    compile: [batch_size: 64, sequence_length: 512],
    defn_options: [compiler: EXLA],
    output_attribute: :hidden_state,
    output_pool: :mean_pooling
  )

** (ArgumentError) could not match the class name "JinaBertForMaskedLM" to any of the supported models, please specify the :module and :architecture options
    (bumblebee 0.6.0) lib/bumblebee.ex:434: Bumblebee.do_load_spec/4
    (bumblebee 0.6.0) lib/bumblebee.ex:603: Bumblebee.maybe_load_model_spec/3
    (bumblebee 0.6.0) lib/bumblebee.ex:591: Bumblebee.load_model/2
    #cell:gukmtcrvamdu57kh:3: (file)

Bummer, looks like you can't use the awesome Jina model :(

Do I hear a "never mind, let's build it in Python and cope with the shortcomings"?

Don't do this to yourself. If you stick around I'll show you how to convert the Jina model to Elixir in 3 steps:

  1. Know your model
  2. Implement it
  3. Verify your work

(the fast lane: talk to bitcrowd, we can help you with that)

Know your model

Let's start with the model you want to run. In our case this is the Jina embeddings v2 model that you can find here: https://huggingface.co/jinaai/jina-embeddings-v2-base-en

As with everything in the AI world, there are many smart words on the model page and you'd need quite some time to learn about everything.

We're doing a shortcut here: It's kind of a secret to people not familiar with the topic, but actually most of the new models are based on some existing model. So, the first step is to identify the base architecture on the model page...

jina-embeddings-v2-base-en is an English, monolingual embedding model supporting 8192 sequence length. It is based on a BERT architecture (JinaBERT) that supports the symmetric bidirectional variant of ALiBi to allow longer sequence length.

Alright, now we know that the Jina model we want to run is based on BERT, with the addition of something they call ALiBi.

You can read more about ALiBi here.

In essence, it replaces position embeddings with a simpler linear bias. The author was kind enough to add instructions on how to use ALiBi:

The implementation is very simple.

  1. Remove the position embeddings from the model
  2. Set up the relative bias matrix
  3. Add the bias matrix to the mask, which is then added in each attention score computation
  4. (This might not be necessary in other frameworks.) Move the mask computation to before the layer loop, to make the transformer a tiny bit faster
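
To make steps 2 and 3 a bit more tangible, here is a minimal pure-Python sketch of a symmetric linear bias being added to attention scores before the softmax. The function names, the single slope of 0.5 and the uniform scores are assumptions made up for illustration; they are not taken from the ALiBi or Jina code.

```python
import math


def alibi_bias(size, slope):
    """Symmetric (bidirectional) linear bias: -slope * |i - j|."""
    return [[-slope * abs(i - j) for j in range(size)] for i in range(size)]


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    highest = max(row)
    exps = [math.exp(value - highest) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(scores, bias):
    """Add the bias to the raw attention scores, then normalize each row."""
    return [
        softmax([s + b for s, b in zip(score_row, bias_row)])
        for score_row, bias_row in zip(scores, bias)
    ]


size = 4
bias = alibi_bias(size, slope=0.5)  # hypothetical slope for a single head
# Uniform raw scores, so any difference in the weights comes from the bias alone.
scores = [[0.0] * size for _ in range(size)]
weights = attention_weights(scores, bias)
print([round(w, 3) for w in weights[0]])
```

With identical raw scores, the first token attends most strongly to itself and less to every token further away, which is the positional signal that replaces the position embeddings.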

Why Jina embeddings v2?

Jina AI has introduced a new generation of text embedding models called jina-embeddings-v2 [1], which offer several key benefits:

  • Extended Context Length: These models support a context length of 8K (8192) tokens, so a file spans fewer chunks and code files often fit in a single chunk.
  • High Performance: The models rank among the top performers on Hugging Face's MTEB leaderboard for embedding models, especially considering their small size and the extended context length. [2]
  • Open Source: Jina AI has made these models open source, democratizing access to high-performance embedding technology that was previously only available through proprietary models like OpenAI's text-embedding-ada-002. [3]

These features make the Jina embeddings v2 models a great choice, especially for code search and multi-lingual documents. They also avoid the caveats of using OpenAI's models. bitcrowd is striving to make local AI technology accessible.

Debugging the problem

As a short sidetrack, let's have a look at what happened when we couldn't run our model in Bumblebee.

This is the relevant stacktrace:

** (ArgumentError) could not match the class name "JinaBertForMaskedLM" to any of the supported models, please specify the :module and :architecture options
(bumblebee 0.6.0) lib/bumblebee.ex:434: Bumblebee.do_load_spec/4
(bumblebee 0.6.0) lib/bumblebee.ex:603: Bumblebee.maybe_load_model_spec/3
(bumblebee 0.6.0) lib/bumblebee.ex:591: Bumblebee.load_model/2

Following the error, we look into the do_load_spec/4 function at line 434 in lib/bumblebee.ex of our local clone of the Bumblebee repository and we see that there is no model because a few lines above infer_model_type/1 did not return one. Inside infer_model_type/1 we do a lookup of JinaBertForMaskedLM in a @transformers_class_to_model map, which returns nil. This is the root cause of our error.

bumblebee/lib/bumblebee.ex
defp infer_model_type(%{"architectures" => [class_name]}) do
  case @transformers_class_to_model[class_name] do
    nil ->
      {:error,
       "could not match the class name #{inspect(class_name)} to any of the supported models"}

    {module, architecture} ->
      {:ok, module, architecture}
  end
end

When we scroll up to look at @transformers_class_to_model we can see that there is a long list of model names that map to tuples of an Elixir module and an atom.

bumblebee/lib/bumblebee.ex
...
"BartForSequenceClassification" => {Bumblebee.Text.Bart, :for_sequence_classification},
"BartModel" => {Bumblebee.Text.Bart, :base},
"BertForMaskedLM" => {Bumblebee.Text.Bert, :for_masked_language_modeling},
...

What's happening here?

You need these three things to run a model:

  1. An implementation of the model as code
  2. The weights of the model
  3. The configuration of the model

The configuration is what Bumblebee looks for on Hugging Face when we tell it that repo = {:hf, "jinaai/jina-embeddings-v2-base-en"}. When you look at that actual repository, you'll find a bunch of config files, including a config.json.

There, you'll see that the architecture is specified as JinaBertForMaskedLM:

jina-embeddings-v2-base-en/config.json
...
"architectures": [
"JinaBertForMaskedLM"
],
...

Bumblebee takes that information and looks into @transformers_class_to_model to check if it has an implementation available. Then, it would proceed to download the corresponding weights and run your model.

In our case, there is no implementation available, and therefore @transformers_class_to_model doesn't have the JinaBertForMaskedLM key.

Alright, now that we know a bit better what's going on, this is what we have to do:

  1. Add a mapping from the JinaBertForMaskedLM key to our implementation
  2. Implement the model in Bumblebee
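
The lookup that failed can be sketched in a few lines of Python. This is only an illustration of the mechanism, not Bumblebee's code: the dictionary is a shortened, hypothetical stand-in for @transformers_class_to_model, and Bumblebee.Text.JinaBert is the module we have yet to write.

```python
import json

# Hypothetical, shortened stand-in for Bumblebee's @transformers_class_to_model map.
CLASS_TO_MODEL = {
    "BartModel": ("Bumblebee.Text.Bart", "base"),
    "BertForMaskedLM": ("Bumblebee.Text.Bert", "for_masked_language_modeling"),
}


def infer_model_type(config, mapping):
    """Return (module, architecture) for the single class listed in the config."""
    (class_name,) = config["architectures"]
    if class_name not in mapping:
        raise ValueError(
            f'could not match the class name "{class_name}" to any of the supported models'
        )
    return mapping[class_name]


# Excerpt of the model's config.json as quoted above.
config = json.loads('{"architectures": ["JinaBertForMaskedLM"]}')

try:
    infer_model_type(config, CLASS_TO_MODEL)
    first_attempt = "ok"
except ValueError as error:
    first_attempt = str(error)

# Step 1 of the to-do list: register the new class name.
extended = {
    **CLASS_TO_MODEL,
    "JinaBertForMaskedLM": ("Bumblebee.Text.JinaBert", "for_masked_language_modeling"),
}
second_attempt = infer_model_type(config, extended)

print(first_attempt)
print(second_attempt)
```

The first lookup fails with the same kind of message we saw in the stacktrace; the second one succeeds once the key exists. Of course, the module behind that key still has to be implemented.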

The difference between BERT and JinaBert

So, if we have a closer look at the @transformers_class_to_model map, we can see that BertForMaskedLM is already supported. This means that there is an existing implementation of BERT in Bumblebee. Since the Jina embeddings v2 model is based on BERT, all we have to do is figure out the difference between Jina embeddings v2 and BERT in their original Python implementations. Then we can go on to add the implementation of that difference to the Elixir implementation of BERT.

We could stare at the code of the Jina BERT implementation and the code of BERT in the transformers library of Hugging Face until we know the ins and outs of both.

Why bitcrowd?

bitcrowd is a team of Elixir enthusiasts that have been building machine learning solutions since 2019. We are passionate about making AI technology accessible and have worked with Python and Elixir setups. The motivation for this blog post is that we see a significant performance boost and a reduction in complexity when using Elixir to run machine learning models.

Or, we can just use Git to get the diff between those files. You can do the same by cloning the two repositories into one directory and running git diff --no-index transformers/src/transformers/models/bert/modeling_bert.py jina-bert-implementation/modeling_bert.py. The --no-index flag tells Git to compare two arbitrary paths instead of revisions of a single repository.

If you squint, you can see that quite a few of the differing lines are simple renames from Bert... to JinaBert..., while others are just formatted differently:

diff --git a/transformers/src/transformers/models/bert/modeling_bert.py b/jina-bert-implementation/modeling_bert.py
- from .configuration_bert import BertConfig
+ from .configuration_bert import JinaBertConfig

(...)

- n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+ n
+ in [
+ "adam_v",
+ "adam_m",
+ "AdamWeightDecayOptimizer",
+ "AdamWeightDecayOptimizer_1",
+ "global_step",
+ ]

Ignoring all this noise we can find the relevant bits of the implementation. They added ALiBi as we can see here and pass it as bias to the layer module:

jina-bert-implementation/modeling_bert.py
def rebuild_alibi_tensor(
    self, size: int, device: Optional[Union[torch.device, str]] = None
):
    # Alibi
    # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
    # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
    # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
    # will be applied, it is necessary to construct the diagonal mask.
    n_heads = self.num_attention_heads

    def _get_alibi_head_slopes(n_heads: int) -> List[float]:
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n_heads).is_integer():
            return get_slopes_power_of_2(
                n_heads
            )  # In the paper, we only train models that have 2^a heads for some a. This function has
        else:  # some good properties that only occur when the input is a power of 2. To maintain that even
            closest_power_of_2 = 2 ** math.floor(
                math.log2(n_heads)
            )  # when the number of heads is not a power of 2, we use this workaround.
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + _get_alibi_head_slopes(2 * closest_power_of_2)[0::2][
                    : n_heads - closest_power_of_2
                ]
            )

    context_position = torch.arange(size, device=device)[:, None]
    memory_position = torch.arange(size, device=device)[None, :]
    relative_position = torch.abs(memory_position - context_position)
    # [n_heads, max_token_length, max_token_length]
    relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1)
    slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device) * -1
    alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
    # [1, n_heads, max_token_length, max_token_length]
    alibi = alibi.unsqueeze(0)
    assert alibi.shape == torch.Size([1, n_heads, size, size])

    self._current_alibi_size = size
    return alibi

(...)

self.register_buffer(
    "alibi",
    self.rebuild_alibi_tensor(size=config.max_position_embeddings),
    persistent=False,
)

(...)

alibi_bias = self.alibi[:, :, :seqlen, :seqlen]

(...)

layer_outputs = layer_module(
    hidden_states,
    attention_mask,
    layer_head_mask,
    encoder_hidden_states,
    encoder_attention_mask,
    alibi_bias,  # <-- here
    past_key_value,
    output_attentions,
)
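
To get a feeling for the slopes, we can pull the slope computation out into a standalone script without torch. The logic mirrors the excerpt above; the function name alibi_head_slopes and the head counts 8 and 12 are my choice for the example (12 is an assumption for a BERT-base-sized model, check the model's config.json for the real value).

```python
import math


def get_slopes_power_of_2(n):
    # Geometric sequence: start, start^2, ..., start^n
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start**i for i in range(n)]


def alibi_head_slopes(n_heads):
    if math.log2(n_heads).is_integer():
        return get_slopes_power_of_2(n_heads)
    # Not a power of two: take the slopes of the closest smaller power of two
    # and fill up with every second slope of the next larger power of two.
    closest = 2 ** math.floor(math.log2(n_heads))
    extra = alibi_head_slopes(2 * closest)[0::2][: n_heads - closest]
    return get_slopes_power_of_2(closest) + extra


slopes_8 = alibi_head_slopes(8)
slopes_12 = alibi_head_slopes(12)
print(slopes_8)
print(len(slopes_12))
```

For 8 heads the slopes are simply 1/2, 1/4, and so on down to 1/256. For 12 heads the first 8 slopes are the same, followed by 4 slopes borrowed from the 16-head sequence.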

Moreover, they added a JinaBertGLUMLP class, set it as mlp, and use that instead of an intermediate and an output layer:

jina-bert-implementation/modeling_bert.py
class JinaBertGLUMLP(nn.Module):
    def __init__(self, config: JinaBertConfig):
        super().__init__()
        self.config = config
        self.gated_layers = nn.Linear(
            config.hidden_size, config.intermediate_size * 2, bias=False
        )
        if config.feed_forward_type == 'reglu':
            self.act = nn.ReLU()
        elif config.feed_forward_type == 'geglu':
            self.act = nn.GELU()
        else:
            raise ValueError(
                f"feed_forward_type {config.feed_forward_type} not supported"
            )
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual_connection = hidden_states
        # compute the activation
        hidden_states = self.gated_layers(hidden_states)
        gated = hidden_states[:, :, : self.config.intermediate_size]
        non_gated = hidden_states[:, :, self.config.intermediate_size :]
        hidden_states = self.act(gated) * non_gated
        hidden_states = self.dropout(hidden_states)
        # multiply by the second matrix
        hidden_states = self.wo(hidden_states)
        # add the residual connection and post-LN
        hidden_states = self.layernorm(hidden_states + residual_connection)
        return hidden_states

(...)

if self.feed_forward_type.endswith('glu'):
    self.mlp = JinaBertGLUMLP(config)

(...)

if self.feed_forward_type.endswith('glu'):
    layer_output = self.mlp(attention_output)
else:
    # this is the code called in the regular BERT implementation
    self.intermediate = JinaBertIntermediate(config)
    self.output = JinaBertOutput(config)
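
If the gating is hard to picture, the following toy version walks a single two-dimensional vector through the same steps with plain Python lists. It is a simplified sketch under several assumptions: it uses the 'reglu' variant because ReLU is easy to check by hand, it leaves out dropout, the bias of wo and the learnable layer norm parameters, and all weights are made up.

```python
import math


def linear(weights, vector):
    """Multiply a row-major weight matrix with a vector (no bias)."""
    return [sum(w * v for w, v in zip(row, vector)) for row in weights]


def layer_norm(vector, eps=1e-12):
    """Normalize to zero mean and unit variance (no learnable scale or shift)."""
    mean = sum(vector) / len(vector)
    var = sum((v - mean) ** 2 for v in vector) / len(vector)
    return [(v - mean) / math.sqrt(var + eps) for v in vector]


def glu_mlp(hidden, gated_layers, wo, intermediate_size):
    projected = linear(gated_layers, hidden)  # length 2 * intermediate_size
    gated = projected[:intermediate_size]
    non_gated = projected[intermediate_size:]
    # 'reglu': ReLU on the gated half, then element-wise product with the other half.
    activated = [max(g, 0.0) * n for g, n in zip(gated, non_gated)]
    out = linear(wo, activated)
    # The residual connection and the post-layer-norm live inside the MLP.
    return layer_norm([o + h for o, h in zip(out, hidden)])


hidden = [1.0, -2.0]
gated_layers = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]]  # 4 x 2, made up
wo = [[1.0, 0.0], [0.0, 1.0]]  # 2 x 2 identity, made up
output = glu_mlp(hidden, gated_layers, wo, intermediate_size=2)
print(output)
```

The projection yields [1, -2, 2, -6]. The first half is the gate, the second half the values; ReLU zeroes the negative gate, so only the first component survives before the residual and the layer norm are applied.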

So, with all that information, we are going to implement it in Elixir!

Implement it

We know that we want to base our implementation on BERT, so let's take our local clone of Bumblebee, make a copy of the BERT implementation and rename it to JinaBert.

bumblebee/lib/bumblebee/text/jina_bert.ex
defmodule Bumblebee.Text.JinaBert do
...
end

To make Bumblebee find our implementation, we must add it to @transformers_class_to_model.

bumblebee/lib/bumblebee.ex
...
"JinaBertForMaskedLM" => {Bumblebee.Text.JinaBert, :for_masked_language_modeling},
"JinaBertModel" => {Bumblebee.Text.JinaBert, :base},
...

Now, when you try to run the model, you shouldn't get any error, but only a bunch of warnings about unused parameters. That's because our implementation is still the regular BERT implementation, but we download and apply weights for the Jina embeddings v2 model. Therefore, there is a mismatch between the parameters and the implementation.

However, we can use the warnings for our purpose. Since we want to apply all of the weights of the Jina model, we know that we are not done until all of the warnings have disappeared.

If you want to jump ahead you can find the commit of the implementation here, but I'd appreciate it if you stuck around.

Discuss your AI project with us

bitcrowd offers a free consultation for your AI project. We are a team of Elixir, Python and machine learning experts. We can help you with your project in any stage.

Let's start with computing the ALiBi matrix since this is what they've called out as the major difference. To keep it as simple as possible, we closely follow the Python version.

bumblebee/lib/bumblebee/text/jina_bert.ex
defp get_slopes_power_of_2(n) do
  start = 2 ** -(2 ** -(:math.log2(n) - 3))
  ratio = start
  for i <- 0..(n - 1), do: start * ratio ** i
end

defp integer?(number) do
  round(number) == number
end

defp get_alibi_head_slopes(n_heads) do
  if integer?(:math.log2(n_heads)) do
    get_slopes_power_of_2(n_heads)
  else
    closest_power_of_2 = 2 ** round(:math.floor(:math.log2(n_heads)))

    get_slopes_power_of_2(closest_power_of_2) ++
      (get_alibi_head_slopes(2 * closest_power_of_2)
       |> Enum.take_every(2)
       |> Enum.take(n_heads - closest_power_of_2))
  end
end

defp alibi_matrix(num_attention_heads, size) do
  context_position = Nx.iota({1, size, 1}, axis: 1)
  memory_position = Nx.iota({1, size}, axis: 1)
  relative_position = Nx.abs(Nx.subtract(context_position, memory_position))
  relative_position = Nx.tile(relative_position, [num_attention_heads, 1, 1])

  slopes = Nx.tensor(get_alibi_head_slopes(num_attention_heads)) |> Nx.multiply(-1)

  slopes
  |> Nx.new_axis(-1)
  |> Nx.new_axis(-1)
  |> Nx.multiply(relative_position)
  |> Nx.new_axis(0)
end

We could probably make this code nicer and lean more into the Elixir way of doing things, but for now I'd say this is fine.

Oh, and we must use the ALiBi matrix as bias, of course.

bumblebee/lib/bumblebee/text/jina_bert.ex
alibi_relative_bias_matrix =
  Axon.nx(hidden_state, fn hidden_state ->
    {_, seqlen, _} = Nx.shape(hidden_state)
    matrix = alibi_matrix(spec.num_attention_heads, spec.max_positions)
    matrix[[.., .., 0..(seqlen - 1), 0..(seqlen - 1)]]
  end)

Layers.Transformer.blocks(
  hidden_state,
  [
    attention_mask: attention_mask,
    attention_head_mask: attention_head_mask,
    attention_relative_bias: alibi_relative_bias_matrix, # <- here it is
    cache: cache,
    causal: decoder?,
    ...

... and remove the position embeddings as described by the author of ALiBi.
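
Why is it fine to build the matrix for the maximum length and then slice it down to the actual sequence length? Because every entry only depends on the distance |i - j|, never on the total size. The small Python check below illustrates that with nested lists; the two slopes and the sizes 8 and 3 are arbitrary values picked for the example.

```python
def alibi_matrix(slopes, size):
    """Bias with shape [1, n_heads, size, size] as nested lists."""
    return [
        [
            [[-slope * abs(i - j) for j in range(size)] for i in range(size)]
            for slope in slopes
        ]
    ]


def slice_to(matrix, seqlen):
    """Keep the top-left seqlen x seqlen block of every head."""
    return [
        [[row[:seqlen] for row in head[:seqlen]] for head in batch]
        for batch in matrix
    ]


slopes = [0.5, 0.25]  # hypothetical two-head model
full = alibi_matrix(slopes, 8)  # computed once for the maximum length
sliced = slice_to(full, 3)  # what a 3-token input would use
print(sliced == alibi_matrix(slopes, 3))
```

The sliced block is identical to a matrix computed directly for three tokens, so slicing loses nothing.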

Next on our list is the JinaBertGLUMLP class.

In Elixir, we implement that simply as a function. Again, we'll stick closely to the Python implementation and naming to keep it aligned.

bumblebee/lib/bumblebee/text/jina_bert.ex
def glumlp(hidden_states, spec, opts) do
  name = opts[:name]
  intermediate_size = spec.intermediate_size
  activation = spec.activation
  hidden_dropout_prob = spec.dropout_rate
  hidden_size = spec.hidden_size
  layer_norm_eps = spec.layer_norm_epsilon

  residual_connection = hidden_states

  hidden_states =
    hidden_states
    |> Axon.dense(intermediate_size * 2, use_bias: false, name: join(name, "gated_layers"))

  gated =
    Axon.nx(hidden_states, fn hidden_states ->
      hidden_states[[.., .., 0..(intermediate_size - 1)]]
    end)
    |> Axon.activation(activation)

  non_gated =
    Axon.nx(hidden_states, fn hidden_states ->
      hidden_states[[.., .., intermediate_size..-1//1]]
    end)

  hidden_states =
    Axon.multiply(gated, non_gated)
    |> Axon.dropout(rate: hidden_dropout_prob)
    |> Axon.dense(hidden_size, name: join(name, "wo"))

  hidden_states
  |> Axon.add(residual_connection)
  |> Axon.layer_norm(epsilon: layer_norm_eps, name: join(name, "layernorm"))
end

We also need a custom implementation of the transformer block. In JinaBert, the block itself must neither add the shortcut nor normalize the output after the feed-forward network, because glumlp already applies the residual connection and the layer norm internally. Only two lines change in comparison to the standard block_impl, but since it's a custom implementation we must create a function.

bumblebee/lib/bumblebee/text/jina_bert.ex
defp jina_block_impl(hidden_state, steps, _name) do
  shortcut = hidden_state

  {hidden_state, attention_info} = steps.self_attention.(hidden_state)

  hidden_state =
    hidden_state
    |> Axon.add(shortcut)
    |> steps.self_attention_norm.()

  {hidden_state, cross_attention_info} =
    steps.cross_attention_maybe.(hidden_state, fn hidden_state ->
      shortcut = hidden_state

      {hidden_state, cross_attention_info} = steps.cross_attention.(hidden_state)

      hidden_state =
        hidden_state
        |> Axon.add(shortcut)
        |> steps.cross_attention_norm.()

      {hidden_state, cross_attention_info}
    end)

  hidden_state =
    hidden_state
    |> steps.ffn.()
    # here we removed two lines
    # |> Axon.add(shortcut)
    # |> steps.output_norm.()

  {hidden_state, attention_info, cross_attention_info}
end
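
To see why those two lines have to go, here is a toy comparison in Python. It is a sketch with made-up pieces: toy_ffn_with_post_ln stands in for glumlp (some non-linear transformation, then residual and layer norm), and the two *_block_tail functions stand in for the end of the standard and the custom block. None of these names exist in Bumblebee.

```python
import math


def layer_norm(vector, eps=1e-12):
    mean = sum(vector) / len(vector)
    var = sum((v - mean) ** 2 for v in vector) / len(vector)
    return [(v - mean) / math.sqrt(var + eps) for v in vector]


def toy_ffn_with_post_ln(hidden):
    """Stand-in for glumlp: transform, add the residual, normalize."""
    transformed = [h * h for h in hidden]
    return layer_norm([t + h for t, h in zip(transformed, hidden)])


def standard_block_tail(hidden, ffn):
    """Default block: shortcut and norm are applied around the FFN."""
    return layer_norm([f + h for f, h in zip(ffn(hidden), hidden)])


def jina_block_tail(hidden, ffn):
    """Custom block: the FFN output is used as is."""
    return ffn(hidden)


hidden = [1.0, 2.0, 6.0]
jina_out = jina_block_tail(hidden, toy_ffn_with_post_ln)
standard_out = standard_block_tail(hidden, toy_ffn_with_post_ln)
print(jina_out)
print(standard_out)
```

With the standard tail, the residual would be added and normalized a second time, and the output no longer matches what the FFN alone produces. The original Python model would not agree with us in that case.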

Then, we pass the glumlp and jina_block_impl functions to our transformer blocks.

bumblebee/lib/bumblebee/text/jina_bert.ex
Layers.Transformer.blocks(
  hidden_state,
  [
    ...
    ffn: &glumlp(&1, spec, name: &2),
    block_type: &jina_block_impl/3,
    ...

At the end of each Bumblebee module file there is a mapping between the names of the layers in Elixir and Python. This way, each layer gets the correct parameters from the weights downloaded from Hugging Face.

We must add the mapping of the names of the layers we added ({n} is a placeholder for a number).

bumblebee/lib/bumblebee/text/jina_bert.ex
defimpl Bumblebee.HuggingFace.Transformers.Model do
  def params_mapping(_spec) do
    %{
      ...
      "encoder.blocks.{n}.ffn.wo" => "encoder.layer.{n}.mlp.wo",
      "encoder.blocks.{n}.ffn.layernorm" => "encoder.layer.{n}.mlp.layernorm",
      "encoder.blocks.{n}.ffn.gated_layers" => "encoder.layer.{n}.mlp.gated_layers"
    }
  end
end
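
If the {n} placeholder looks a bit magic, the following Python sketch shows the idea of expanding such templates into concrete layer names. It only illustrates the concept; how Bumblebee actually resolves the placeholders is not shown in this post, and expand_templates and num_blocks are names I made up.

```python
# The three mappings we added, written as target name -> source name templates.
LAYER_TEMPLATES = {
    "encoder.blocks.{n}.ffn.wo": "encoder.layer.{n}.mlp.wo",
    "encoder.blocks.{n}.ffn.layernorm": "encoder.layer.{n}.mlp.layernorm",
    "encoder.blocks.{n}.ffn.gated_layers": "encoder.layer.{n}.mlp.gated_layers",
}


def expand_templates(templates, num_blocks):
    """Replace the {n} placeholder with every block index."""
    return {
        target.replace("{n}", str(n)): source.replace("{n}", str(n))
        for target, source in templates.items()
        for n in range(num_blocks)
    }


expanded = expand_templates(LAYER_TEMPLATES, num_blocks=2)
for elixir_name, python_name in sorted(expanded.items()):
    print(f"{elixir_name} <- {python_name}")
```

For a model with two blocks this yields six concrete pairs, one per layer and block, each telling us which Python parameter feeds which Elixir layer.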

And we are done.

Verify it

After implementing the model we should verify that it's actually correct.

We reimplemented the Python code, so this is our plan to verify the model:

  1. run the model in Python for a given input
  2. capture the output
  3. run the model in Elixir for the same input
  4. compare the output with the previously captured one

You can get more elaborate than this and, for instance, run a number of random inputs and compare the corresponding outputs. I have some more thoughts about testing, but that's for another blog post. For now a single input-output pair is sufficient.

This is the Python script we can use to get the output.

from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
embedding = model.encode('How is the weather today?')
print(embedding[0:3])

# prints [-0.34827134 -0.60091823 0.60223615]

We only look at the first three numbers of the embedding for now.

Next, we run the same input through our Elixir model.

repo = {:hf, "jinaai/jina-embeddings-v2-base-en"}

{:ok, %{model: model, params: params, spec: spec} = model_info} =
  Bumblebee.load_model(repo,
    params_filename: "model.safetensors",
    spec_overrides: [architecture: :base]
  )

{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

serving =
  Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
    compile: [batch_size: 2, sequence_length: 512],
    defn_options: [compiler: EXLA],
    output_attribute: :hidden_state,
    output_pool: :mean_pooling
  )

%{embedding: embedding} = Nx.Serving.run(serving, "How is the weather today?")

dbg(embedding[0..2])

# prints
# #Nx.Tensor<
# f32[3]
# [-0.34827110171318054, -0.600918173789978, 0.602236270904541]
# >

In comparison to our first attempt to run the model, we have some changes here: We tell Bumblebee to look in model.safetensors for the parameters (otherwise I got a CRC mismatch error that I couldn't really make sense of). We also override the spec to use the :base architecture of our model. This way, we get the last hidden state as output and can let the text_embedding serving apply mean pooling.
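
In case mean pooling is new to you: it turns the per-token vectors of the last hidden state into a single vector by averaging them. The sketch below shows the idea in plain Python, assuming that padding tokens are excluded via the attention mask; the function name and the numbers are made up.

```python
def mean_pool(hidden_states, attention_mask):
    """Average the token vectors whose mask entry is 1."""
    kept = [vec for vec, keep in zip(hidden_states, attention_mask) if keep]
    return [sum(column) / len(kept) for column in zip(*kept)]


# Three tokens with two dimensions each; the last one is padding.
hidden_states = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
attention_mask = [1, 1, 0]
embedding = mean_pool(hidden_states, attention_mask)
print(embedding)  # prints [2.0, 3.0]
```

The padding token does not influence the result, so the embedding only reflects the real tokens.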

Then, we can run the model.

Note that the output values are slightly different in comparison to the Python version. When working with different frameworks, numbers can be slightly different due to small differences in the implementation.

We can verify that the output values are within a small tolerance using Nx.all_close/3.

Nx.all_close(Nx.tensor([-0.34827134, -0.60091823, 0.60223615]), embedding[0..2])
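
To see what such a tolerance check does with our numbers, here is the same comparison spelled out in Python. I'm assuming a relative tolerance of 1e-5 and an absolute tolerance of 1e-8, which are NumPy's defaults for its allclose; check the Nx documentation for the exact defaults of Nx.all_close/3.

```python
def all_close(a, b, rtol=1e-5, atol=1e-8):
    """True if every pair satisfies |x - y| <= atol + rtol * |y|."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


python_values = [-0.34827134, -0.60091823, 0.60223615]
elixir_values = [-0.34827110171318054, -0.600918173789978, 0.602236270904541]

max_diff = max(abs(x - y) for x, y in zip(python_values, elixir_values))
print(max_diff)
print(all_close(python_values, elixir_values))
print(all_close(python_values, elixir_values, rtol=0.0, atol=1e-8))
```

The largest difference is in the order of 1e-7, which passes with the relative tolerance but would fail a purely absolute check at 1e-8. So pick the tolerances consciously instead of relying on defaults you don't know.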

Reach out to bitcrowd if you want to make awesome Elixir AI projects happen (or really any awesome Elixir project).

References

  1. Günther, M., Ong, J., Mohr, I., Abdessalem, A., Abel, T., Akram, M.K., Guzman, S., Mastrapas, G., Sturua, S., Wang, B., Werk, M., Wang, N., & Xiao, H. (2023). Jina Embeddings 2: 8192-Token General-Purpose Text Embeddings for Long Documents. arXiv:2310.19923.
  2. MTEB Leaderboard
  3. Jina Embeddings v2
  4. Bumblebee
  5. ALiBi