LLMs have opened up a whole new way of doing tricky NLP tasks, one that I haven't seen documented in many other places yet.
Recently I needed to disentangle a text field that was a messy combination of a date (written out in words) and a place of birth, for example "Fourteenth of June 1956 (8:05am), Sacred Heart Hospital, Seattle". I wanted to extract the date and the place of birth separately.
I could have used a regex to extract the date, but as soon as the data didn't fit the pattern everything would fall to bits. I was dealing with 100,000+ rows of messy, manually entered data, so I wanted an approach that could cope with that variation. I tried spaCy, but it didn't work very well, and I would have had to write a lot of rules to get it to almost work.
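To illustrate the brittleness, here's roughly what the regex approach looks like (a minimal sketch with a pattern and field names I've made up for this post): it handles the well-formed rows and silently gives up on anything slightly different.

import re

# A hypothetical pattern for strings like "Fourteenth of June 1956 (8:05am), Sacred Heart Hospital, Seattle"
PATTERN = re.compile(
    r"(?P<day>\w+)\s+of\s+(?P<month>\w+)\s+(?P<year>\d{4})"
    r"(?:\s*\((?P<time>[^)]+)\))?,\s*(?P<place>.+)",
    re.IGNORECASE,
)
print(PATTERN.match("Fourteenth of June 1956 (8:05am), Sacred Heart Hospital, Seattle"))  # matches
print(PATTERN.match("ELEVENTH NOVEMBER. 1971 GENERAL HOSPITAL ST. PAUL'S WING. London"))  # None: no ' of ', no comma before the place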
I decided to try using an LLM (Large Language Model) to do the task. I, and the whole world, have had nothing but LLMs in our faces for the last year, so I was curious to see if they would work in this context. It turns out that they worked really well and made the task super easy!
The first thing I tried was to prompt the model with the text I wanted to extract from. I used the Hugging Face Transformers library to do this, with the "meta-llama/Llama-2-7b-chat-hf" model, which is the smallest of the Llama-2 chat models.
My environment for this experiment was a Paperspace Gradient notebook, usually running on a P6000 GPU.
First steps - install the required libraries and download the model.
%%capture
!pip install torch tokenizers transformers accelerate pydantic bitsandbytes faker --upgrade
!pip install num2words peft datasets --upgrade
import torch
import textwrap
import transformers
from huggingface_hub import notebook_login
notebook_login()
model = "meta-llama/Llama-2-7b-chat-hf"
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)
The last line obviously does a lot of heavy lifting. It downloads the model from the Hugging Face Hub, loads it into memory, and sets up a pipeline for generating text. The pipeline is a wrapper around the model that makes it easy to use: it takes care of tokenizing the input text, running it through the model, and decoding the output.
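For intuition, this is roughly what the pipeline is doing for us under the hood. This is only a sketch using the standard transformers classes, with illustrative generation settings:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained(model)
lm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16, device_map="auto")
# Tokenize, generate, decode - the three steps the pipeline wraps up for us
inputs = tokenizer("Some prompt text", return_tensors="pt").to(lm.device)
generated = lm.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(generated[0], skip_special_tokens=True))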
At this point, we can start generating text. Let's try it out with a simple prompt.
prompt = "Fourteenth of June 1956 (8:05am), Sacred Heart Hospital, Seattle"
output = pipeline(
    prompt,
    top_k=20,
    do_sample=True,
    num_return_sequences=1,
    eos_token_id=pipeline.tokenizer.eos_token_id,
    max_new_tokens=100,
    return_full_text=False
)
print(output)
[{'generated_text': ", Washington, USA.\n\nThe baby boy was born weighing 7 lbs 6 oz and measuring 20 inches long. His mother, Mary, was 25 years old at the time of his birth. His father, John, was 28 years old. The baby was named Michael after his father.\n\nThe birth certificate also includes information about the baby's ancestry, with his mother listed as having Irish and English ancestry and his"}]
The model has treated the prompt as the start of a story. Not bad, but no good for us. We need to steer the model in the right direction by showing it some example inputs and outputs. We can do this by adding some extra text to the prompt. This is a type of prompt engineering called "few-shot in-context learning", and most importantly it requires no training!
There are some special tags which we can use to craft our prompt. These tags are specific to the model family, so in this case we're using the tags for the Llama-2 chat models. I'm not completely sure whether I'm using them correctly, but they seem to work. Each example starts with a <s> tag and ends with a </s> tag. We differentiate between the user input and the model output by wrapping the user input in [INST] and [/INST] tags. Of course, in reality we wrote all of the examples ourselves, so in a sense we're tricking the model into thinking that the examples are earlier parts of the conversation.
It's important to note that each input to the model is a blank slate, so for every query we show the model the whole set of examples all over again.
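As an aside, recent versions of transformers can build this kind of prompt from a list of messages using the tokenizer's built-in chat template, instead of hand-writing the tags as I do below. A sketch (I haven't checked that it produces exactly the same string as my hand-rolled prompt):

messages = [
    {"role": "system", "content": "Act as a natural language processing API. ..."},
    {"role": "user", "content": "Query: Twenty fifth August 1971. Liverpool Maternity Hospital. Liverpool"},
    {"role": "assistant", "content": "1971-08-25; Liverpool Maternity Hospital, Liverpool"},
    {"role": "user", "content": "Query: SECOND October 1960 Royal Infirmary, Westminister"},
]
chat_prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(chat_prompt)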
def make_prompt_for_input(query_str):
    return textwrap.dedent(f"""
<s>[INST] <<SYS>>
Act as a natural language processing API.
You are to extract the required text from the document in question and return the answer and no more.
The queries that you will be asked to parse are usually a concatenation of a date, an optional time, and an address
You are to return the date, and time if there is one
You are to return an ISO 8601 formatted date and time, and the address split by ';'
Sometimes either the date, or address will be missing. In that case return 'NONE' for that part of the result.
<</SYS>>
[/INST]
<s>[INST]
Query: Twenty fifth August 1971. Liverpool Maternity Hospital. Liverpool
Response: [/INST]1971-08-25; Liverpool Maternity Hospital, Liverpool</s>
<s>[INST]
Query: ELEVENTH NOVEMBER. 1971 GENERAL HOSPITAL ST. PAUL'S WING. London
Response: [/INST]1971-11-11; GENERAL HOSPITAL ST. PAUL'S WING, London</s>
<s>[INST]
Query: SECOND October 1960 Royal Infirmary, Westminister
Response: [/INST]1960-10-2; Royal Infirmary, Westminister</s>
<s>[INST]
Query: Eighteenth June , 1971 11:05am. City Hospital, Chester
Response: [/INST]1971-06-18T11:05:00Z; City Hospital, Chester</s>
<s>[INST]
Query: Upton Hospital, Ugley
Response: [/INST]NONE; Upton Hospital, Ugley</s>
<s>[INST]
Query: {query_str}
Response: [/INST]""")
examples = [
    "First September,1971 (6:24am) Mayday Hospital, Croydon",
    "Fourth December 1980, Flat 6, Joyce springs, South Henry, E5T 4NW",
    "Twelfth MARCH 1979 1.30 p.m. Grace flats, Elaineport",
    "Fifteenth MAY 1975, 329 Sharp river, Lukechester",
    "(6) Eighth SEPTEMBER 1976, OXO building, London",
]
input_prompts = [make_prompt_for_input(ex) for ex in examples]
output = pipeline(
    input_prompts,
    top_k=20,
    do_sample=True,
    num_return_sequences=1,
    eos_token_id=pipeline.tokenizer.eos_token_id,
    max_new_tokens=100,
    return_full_text=False
)
for item in output:
    print(item[0]['generated_text'])
1971-09-01T06:24:00Z; Mayday Hospital, Croydon
1980-12-04T00:00:00Z; Flat 6, Joyce springs, South Henry, E5T 4NW
1979-03-12T13:30:00Z; Grace flats, Elaineport
1975-05-15T00:00:00Z; 329 Sharp River, Lukechester
1976-09-08T00:00:00Z; OXO building, London
Amazing! With a very small amount of code, we've got a very robust model that will happily accept out-of-domain inputs and return the correct answer. It would take a very long time to craft a set of rules that would do the same thing.
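Since the responses follow the "ISO date; address" shape from the examples, splitting them back into structured fields is easy. A minimal sketch (this helper is mine, not something from the notebook above):

from datetime import datetime

def parse_response(text):
    # Expected shape: "<ISO date or NONE>; <address or NONE>"
    date_part, _, address_part = text.partition(';')
    date_part, address_part = date_part.strip(), address_part.strip()
    parsed_date = None
    if date_part != 'NONE':
        # fromisoformat on older Pythons doesn't accept a trailing 'Z'
        parsed_date = datetime.fromisoformat(date_part.rstrip('Z'))
    return parsed_date, (None if address_part == 'NONE' else address_part)

print(parse_response('1971-09-01T06:24:00Z; Mayday Hospital, Croydon'))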
But the native language of the model is numbers, so why should we limit ourselves to telling the model what to do in text? This is the idea behind the next version of the experiment: prompt tuning.
The idea behind prompt tuning is that we can steer the model by prepending a learned vector to the input. We freeze the model and learn the vector that produces the desired output. I think of it as gradient-descent-powered prompt engineering. It's much faster than training the whole model, and it also means we can train multiple task-specific vectors for the same model and swap them in and out as needed. However, because we're going to be optimising the vector, we need some training data.
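As a toy illustration of the mechanics (the sizes and the tiny "model" below are made up purely for illustration, this is not what PEFT actually does internally):

import torch
import torch.nn as nn

hidden_size, n_virtual_tokens, vocab_size = 16, 8, 100
# The 'base model' is frozen...
frozen_embeddings = nn.Embedding(vocab_size, hidden_size)
frozen_layer = nn.Linear(hidden_size, hidden_size)
for p in list(frozen_embeddings.parameters()) + list(frozen_layer.parameters()):
    p.requires_grad = False
# ...and the only thing we train is a small matrix of 'virtual token' embeddings
prompt_embeddings = nn.Parameter(torch.randn(n_virtual_tokens, hidden_size))

input_ids = torch.randint(0, vocab_size, (5,))               # a fake tokenized input
token_embeds = frozen_embeddings(input_ids)                  # (5, hidden_size)
full_input = torch.cat([prompt_embeddings, token_embeds])    # (8 + 5, hidden_size): learned prefix prepended
output = frozen_layer(full_input)                            # passed through the frozen model as usual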
As usual, getting perfect training data is a catch-22: if I had it, I wouldn't need to do this experiment. So I'm going to use some fake data generated by the Faker library. I'll generate a few thousand examples (5,000 here) of the type of text I want to extract, and use those as the training data for the model. To make the data look a bit more realistic, I'll rough it up a little by adding some errors here and there. It's not perfect, but we'll see that it still works pretty well for in-domain and out-of-domain inputs.
Also, for anyone looking to copy this notebook: at this point I restarted it to clear out some memory. I'm sure there is a better way of doing this somehow.
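One alternative that usually works is to drop the references and clear PyTorch's CUDA cache, though I haven't verified it frees enough memory for this exact notebook:

import gc
del pipeline              # drop the reference to the Llama-2 pipeline (and anything else holding the model)
gc.collect()
torch.cuda.empty_cache()  # release cached GPU memory back to the driver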
import os
import random
from datetime import datetime, timedelta
import faker
import torch
from tqdm.auto import tqdm
from num2words import num2words
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup, LlamaForCausalLM
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType, LoraConfig
fake = faker.Faker('en_gb')
We'll start this time by creating our synthetic dataset. The goal is to make a somewhat realistic dataset that is at least illustrative of the task at hand. Generating synthetic data is usually easier said than done, especially when you have to worry about conditional distributions of dependent variables. There have been a number of times where I've had to stop myself falling down a rabbit hole by realising that generating the synthetic data is actually just as much, if not more, work than whatever alternative I had been avoiding. But in this case it's relatively straightforward.
def generate_address():
    """
    Generates strings that look like 'Taylor walk Hospital Jaynechester'
    """
    address = fake.street_address().split('\n')
    if len(address) > 1:
        # The first line of the address starts with something like 'Studio 33', discard this
        address = address[1:]
    if random.random() < 0.5:
        # A lot of people are born in hospitals
        address.append('Hospital')
    address.append(fake.city())
    return ' '.join(address)
def format_date(dt: datetime):
    """
    Formats a date so that it looks like 'twenty third January 1990'
    """
    return "{day} {month} {year}".format(
        day=num2words(dt.day, ordinal=True).replace('-', ' '),
        month=dt.strftime("%B"),
        year=dt.year
    )
def random_datetime(start_year, end_year):
    """
    Generates a date between start year and end year (inclusive)
    """
    start_date = datetime(start_year, 1, 1)
    end_date = datetime(end_year + 1, 1, 1)
    offset = random.randint(0, (end_date - start_date).days)
    return start_date + timedelta(days=offset)
def randomly_delete_spaces(string, deletion_probability=0.05):
    indices = [i for i, char in enumerate(string) if char == " "]
    deleted_indices = [idx for idx in indices if random.random() <= deletion_probability]
    modified_string = "".join([char for i, char in enumerate(string) if i not in deleted_indices])
    return modified_string
dataset = {
    'input_text': [],
    'label_text': [],
}
for i in range(5000):
    date = random_datetime(1900, 2000)
    address = generate_address()
    input_text = randomly_delete_spaces(f'{format_date(date)} {address}')
    label = f'{date.isoformat()};{address}'
    dataset['input_text'].append(input_text)
    dataset['label_text'].append(label)
dataset = Dataset.from_dict(dataset).train_test_split(test_size=0.2, shuffle=True)
# Print a couple of example input-output pairs
for i in range(3):
    print(dataset['test'][i])
{'input_text': 'sixteenth November 1945 9 Christine villages Samanthaberg', 'label_text': '1945-11-16T00:00:00;9 Christine villages Samanthaberg'}
{'input_text': 'twenty sixth January 1989 Ingram points Hospital Robertsland', 'label_text': '1989-01-26T00:00:00;Ingram points Hospital Robertsland'}
{'input_text': 'fourteenth April 1952 972 Davies hill Craigmouth', 'label_text': '1952-04-14T00:00:00;972 Davies hill Craigmouth'}
Next we have to do some preprocessing on these examples so that they are ready to be fed into the language model.
What we'll feed into the model is three vectors:
The first is input_ids, which is the tokenized input and label concatenated together and padded (on the left) out to max_length.
The second is attention_mask, which is 0 for the padding tokens and 1 for the input and label tokens.
The third is labels, the desired output token IDs, with the input and padding positions set to -100 so they are ignored by the loss and the model can't cheat.
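As a concrete miniature example with made-up token IDs and a max_length of 8: suppose the tokenized input is [101, 102, 103, 104], the tokenized label is [201, 202], and the pad token ID is 0. After appending a pad token to the label and left-padding, the three vectors line up like this:

input_ids      = [0, 101, 102, 103, 104, 201, 202, 0]         # left padding, then input, then label (+ trailing pad)
attention_mask = [0, 1, 1, 1, 1, 1, 1, 1]                     # 0 only over the left padding
labels         = [-100, -100, -100, -100, -100, 201, 202, 0]  # -100 everywhere except the label tokens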
model_name_or_path = "bigscience/bloomz-560m"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
text_column = 'input_text'
label_column = 'label_text'
max_length = 64
batch_size = 10
# This preprocess function builds the 3 vectors mentioned above for each example
def preprocess_batch(examples):
    batch_size = len(examples[text_column])
    # tokenize the inputs and expected outputs
    input_texts = [f"{text_column}: {x}; label: " for x in examples[text_column]]
    tokenized_inputs = tokenizer(input_texts)
    tokenized_labels = tokenizer(examples[label_column])
    for i in range(batch_size):
        # Get the tokenized input IDs for the current input, and the label IDs with a trailing pad token
        sample_input_ids = tokenized_inputs["input_ids"][i]
        label_input_ids = tokenized_labels["input_ids"][i] + [tokenizer.pad_token_id]
        # Concatenate the input IDs and label IDs and calculate the padding required
        concat_input_ids = sample_input_ids + label_input_ids
        padded_length = max_length - len(concat_input_ids)
        # Pad the input IDs
        tokenized_inputs["input_ids"][i] = [tokenizer.pad_token_id] * padded_length + concat_input_ids
        # Create the attention mask with 0s for padding tokens and 1s for input and label tokens
        tokenized_inputs["attention_mask"][i] = [0] * padded_length + [1] * len(concat_input_ids)
        # Mask the padding and input positions of the labels with -100 (ignore token)
        tokenized_labels["input_ids"][i] = [-100] * (padded_length + len(sample_input_ids)) + label_input_ids
        # Transform the vectors to pytorch tensors and trim them to max_length
        tokenized_inputs["input_ids"][i] = torch.tensor(tokenized_inputs["input_ids"][i][:max_length])
        tokenized_inputs["attention_mask"][i] = torch.tensor(tokenized_inputs["attention_mask"][i][:max_length])
        tokenized_labels["input_ids"][i] = torch.tensor(tokenized_labels["input_ids"][i][:max_length])
    tokenized_inputs["labels"] = tokenized_labels["input_ids"]
    return tokenized_inputs
# this batch size refers only to the batching of the map function
# there is nothing of dimension size 1000 in the final output shape
processed_datasets = dataset.map(
    preprocess_batch,
    batched=True,
    batch_size=1000,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)
train_dataset = processed_datasets["train"]
test_dataset = processed_datasets["test"]
train_dataloader = DataLoader(train_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)
test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)
Now that we have generated and preprocessed some synthetic data, we can put together our prompt-tuning model.
I'm going to use a smaller model as the base, since it seems to be more stable for this workflow.
Using Hugging Face PEFT (parameter-efficient fine-tuning), we do this by creating a new instance of the model with some extra pieces, as defined in the PromptTuningConfig. If we wanted to fine-tune a model using LoRA, the setup would be very similar, as sketched below.
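For comparison, the LoRA version of the config would look something like this. This is just a sketch: the rank and target modules here are illustrative choices I haven't actually trained with.

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                                 # rank of the low-rank update matrices
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["query_key_value"],  # the fused attention projection in the Bloom blocks
)
# and then, exactly as below: model = get_peft_model(base_model, lora_config)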
base_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
# Create a PEFT config
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=8,
    # This text provides the initial guess for the virtual token values
    prompt_tuning_init_text="Segment the text into a datetime and address",
    tokenizer_name_or_path=model_name_or_path,
)
# Create a new model from the base model with the additions from the peft config
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
model
trainable params: 8,192 || all params: 559,222,784 || trainable%: 0.0014648902430985358
PeftModelForCausalLM(
(base_model): BloomForCausalLM(
(transformer): BloomModel(
(word_embeddings): Embedding(250880, 1024)
(word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(h): ModuleList(
(0-23): 24 x BloomBlock(
(input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(self_attention): BloomAttention(
(query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
(dense): Linear(in_features=1024, out_features=1024, bias=True)
(attention_dropout): Dropout(p=0.0, inplace=False)
)
(post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): BloomMLP(
(dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
(gelu_impl): BloomGelu()
(dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
(ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=1024, out_features=250880, bias=False)
)
(prompt_encoder): ModuleDict(
(default): PromptEmbedding(
(embedding): Embedding(8, 1024)
)
)
(word_embeddings): Embedding(250880, 1024)
)
We now have a model that functions almost the same as the base Bloom model, except with the addition of 8 trainable 'virtual tokens' (8 × 1024 = 8,192 parameters) that will be prepended to the input. These will steer the model in the right direction.
Time to train the model! First, initialise the optimizer and learning rate scheduler.
num_epochs = 5
lr = 3e-2
checkpoint_name = f"dates_and_addresses_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt".replace(
    "/", "_"
)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)
This next bit is a fairly standard model training loop, and could probably be rewritten more elegantly using the Hugging Face Trainer interface.
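For reference, the Trainer version would look roughly like this. A sketch only: I haven't checked that it reproduces the numbers below, and the output directory name is made up.

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="prompt_tuning_output",
    num_train_epochs=num_epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=default_data_collator,
)
# trainer.train()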
device = "cuda"
model = model.to(device)
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(test_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )
    eval_accuracy = [y_true in y_pred for y_true, y_pred in zip(dataset['test']['label_text'], eval_preds)]
    eval_accuracy = sum(eval_accuracy) / len(eval_accuracy)
    eval_epoch_loss = eval_loss / len(test_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=} {eval_accuracy=}")
epoch=0: train_ppl=tensor(2.3170, device='cuda:0') train_epoch_loss=tensor(0.8403, device='cuda:0') eval_ppl=tensor(1.1796, device='cuda:0') eval_epoch_loss=tensor(0.1652, device='cuda:0') eval_accuracy=0.751
epoch=1: train_ppl=tensor(1.1433, device='cuda:0') train_epoch_loss=tensor(0.1339, device='cuda:0') eval_ppl=tensor(1.1117, device='cuda:0') eval_epoch_loss=tensor(0.1059, device='cuda:0') eval_accuracy=0.827
epoch=2: train_ppl=tensor(1.1021, device='cuda:0') train_epoch_loss=tensor(0.0972, device='cuda:0') eval_ppl=tensor(1.0970, device='cuda:0') eval_epoch_loss=tensor(0.0926, device='cuda:0') eval_accuracy=0.842
epoch=3: train_ppl=tensor(1.0865, device='cuda:0') train_epoch_loss=tensor(0.0830, device='cuda:0') eval_ppl=tensor(1.0845, device='cuda:0') eval_epoch_loss=tensor(0.0811, device='cuda:0') eval_accuracy=0.857
epoch=4: train_ppl=tensor(1.0706, device='cuda:0') train_epoch_loss=tensor(0.0683, device='cuda:0') eval_ppl=tensor(1.0716, device='cuda:0') eval_epoch_loss=tensor(0.0692, device='cuda:0') eval_accuracy=0.87
Finally, we can check some outputs of the prompt-tuned model.
inputs = [
    f'{text_column}: sixteenth October 1945 Marion summit Donnaborough; label: ',
    f'{text_column}: fourth may 1966 65 Ian landNormanmouth; label: ',
    f'{text_column}: (1) tenth september 1911Carole points Hospital Popefurt; label: ',
    f'{text_column}: on or about twenth fourth October 1945 Marion summit Donnaborough; label: ',
    f'{text_column}: twenty sixth January 1989; label: ',
    f'{text_column}: 1 Carole points Hospital Popefurt; label: ',
    f'{text_column}: fourthDecember1966LydiaislePortJennamouth; label: ',
]
inputs = tokenizer(inputs, return_tensors='pt', padding=True)
model.to(device)
with torch.no_grad():
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model.generate(**inputs, max_new_tokens=20, eos_token_id=3)
outputs = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)
for output in outputs:
    print(output)
input_text: sixteenth October 1945 Marion summit Donnaborough; label: 1945-10-16T00:00:00;Marion summit Donnaborough
input_text: fourth may 1966 65 Ian landNormanmouth; label: 1966-05-04T00:00:00;65 Ian land Normanmouth
input_text: (1) tenth september 1911Carole points Hospital Popefurt; label: 1911-09-10T00:00:00;Carole points Hospital Popefurt
input_text: on or about twenth fourth October 1945 Marion summit Donnaborough; label: 1945-10-24T00:00:00;Marion summit Donnaborough
input_text: twenty sixth January 1989; label: 1989-01-26T00:00:00;2100-1991-11-21 Nauru Island Nauru IslandAustin, Texas,
input_text: 1 Carole points Hospital Popefurt; label: 1951-01-31T00:00:00;Carole points Hospital Popefurt
input_text: fourthDecember1966LydiaislePortJennamouth; label: 1966-12-04T00:00:00;Lydiaisle Port Jennamouth
Not too bad!
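Since only the 8 virtual token embeddings were trained, the learned prompt is tiny and cheap to save, which is what makes the idea of swapping task-specific vectors in and out practical. A quick sketch (the directory name is made up):

# Saves only the prompt-tuning adapter (a few kilobytes), not the base model
model.save_pretrained("bloomz-560m-dates-and-addresses-prompt")

# Later: rebuild the model by loading the base weights and attaching the saved prompt
from peft import PeftModel
base = AutoModelForCausalLM.from_pretrained(model_name_or_path)
reloaded = PeftModel.from_pretrained(base, "bloomz-560m-dates-and-addresses-prompt")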
Although I was hoping for more than 0.87 accuracy on the test set, I think this method makes up for that by showing what it can do with the out-of-domain examples above. This isn't going to replace rules-based approaches completely, not least because of the dramatically higher resource requirements, but it is a useful tool to have in the box.
To extend this further, the next step might be to look at LoRA fine-tuning, where you train a very small proportion of the overall model; that will probably lead to better results for these tasks. As with all of these approaches, better results probably mean more training time.
Below are some useful resources I found while writing this:
https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/
https://huggingface.co/docs/peft/task_guides/clm-prompt-tuning
I also found this snippet useful for seeing memory estimates for models at different quantisation levels:
!accelerate estimate-memory "meta-llama/Llama-2-7b-chat-hf"
Loading pretrained config for `meta-llama/Llama-2-7b-chat-hf` from `transformers`...
┌────────────────────────────────────────────────────────┐
│Memory Usage for loading `meta-llama/Llama-2-7b-chat-hf`│
├───────┬─────────────┬──────────┬───────────────────────┤
│ dtype │Largest Layer│Total Size│ Training using Adam │
├───────┼─────────────┼──────────┼───────────────────────┤
│float32│ 776.03 MB │ 24.74 GB │ 98.96 GB │
│float16│ 388.02 MB │ 12.37 GB │ 49.48 GB │
│ int8 │ 194.01 MB │ 6.18 GB │ 24.74 GB │
│ int4 │ 97.0 MB │ 3.09 GB │ 12.37 GB │
└───────┴─────────────┴──────────┴───────────────────────┘
And another handy one to diagnose memory issues:
print(torch.cuda.memory_summary(device=None, abbreviated=True))
Note on how I generated this page
This blog is hosted in Jekyll. I don't post often, but whenever I do I find that the set of tools I use has usually moved on a bit from when I last posted, and integrating the new workflow into Jekyll seems to be a pain every time. This time, I wasn't happy to just post a plain Jupyter notebook because I didn't like the way it looked; I wanted it rendered like a GitHub markdown page. I spent some time trying to export the Jupyter notebook as markdown and style it directly in Jekyll, but eventually gave up. The way I found was to export it to markdown, open that markdown file in Typora, apply the default GitHub theme, and then export it as HTML with styles. It feels like there should be an easier way...