Behind SKIM: LLM-powered Contract AI
13 Jan 2024
In this app, I wanted to simplify a key step of the contract review process that many lawyers have to go through: sifting through the boilerplate to establish exactly what obligations are being created between the parties to a contract. You can try it out here - if you are using it for the first time, do wait a little bit for the API to warm up; I’m working on a college student’s budget right now ):
Put in any contract text in the text box, or as a word document…
Easily filter for the obligation category you are interested in
This seems like a task right up the alley of artificial intelligence - especially with the latest advancements in natural language processing in the form of large language models! To get there, I first needed to establish the key elements of the app that I was going to code up. It would have to take text of any length, not necessarily well structured, detect the obligations in it, and then classify them into a pre-defined set of categories that appear across most contracts. These problems are not easy at all, and if you break them into their constituent steps, each requires a significant amount of inference that would not have been possible until recently.
So, in this blog post, I will explain each major step of the app, showing what happens under the hood, and why I made the design/engineering choice that I did. Enjoy!
Text classification via large language models: the core of this app
The core problem of this app was essentially multi-class classification - that is, if I looked at a sentence, I would need to bin it into exactly one of these categories:
- Termination condition
- Performance obligation
- Payment obligation
- Compliance obligation
- Support or warranty obligation
- Insurance obligation
- Quality obligation
- Audit obligation
- Contract renewal
- Exceptions or exemptions
- Indemnification
- Jurisdiction
- None of the above
A huge problem with classifying into this many categories is that it’s difficult to be accurate - if we were to train this classifier in a more traditional way, we would not only need a lot of training data overall - we would need a lot of training data for each label. That adds up very quickly, and isn’t very easy for an independent programmer to collect.
This is why large language models are so useful for this task - they are essentially an extended version of transfer learning, that is, taking weights already trained on a massive corpus of the internet and using them as the starting point for any further training. That further training is known as fine-tuning, and it has been really successful at getting good results from only a few training examples.
Before we train any further though, let’s see some of the capabilities of large language models - available for free, off the shelf at Hugging Face - at zero-shot classification. To do that, we take a large language model’s key functionality - text-to-text generation - and use prompt engineering to get specific, standardized answers that are easy to parse within a backend, thus coercing it into a text classification task. I tested this on Google’s flan-t5-large, using a GPU for inference:
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large").to('cuda') # With GPU --> .to('cuda')
# The clause we want to classify
clause_text = "It is expressly understood that Commerce may unilaterally cancel this Contract for Contractor’s refusal to comply with this provision."
input_prompt = f"""
Identify the clause type. If it is none, reply None
Text: "{clause_text}"
Contractual clause types:
1. Termination condition
2. Performance obligation
3. Payment obligation
4. Compliance obligation
5. Support or warranty obligation
6. Insurance obligation
7. Quality obligation
8. Audit obligation
9. Contract renewal
10. Exceptions or exemptions
11. Indemnification
12. Jurisdiction
"""
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")
outputs = model.generate(input_ids, max_length=512)
category = tokenizer.decode(outputs[0])
>>> print(category)
"<pad> 1. </s>"
Okay, that’s not too bad! It overall does a really great job at making predictions with zero prior information (provided by me) - but after looking at a full contract being processed there were two issues that stood out to me:
(1) there are a few clauses which are classified wrongly, and (2) the answers are being returned in a format that I think is wrong (I want it to be returned as the text, not numbers)
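As a stopgap for issue (2), the numeric answer could be mapped back to its label before any fine-tuning happens. The snippet below is only a sketch of that idea: parse_category and CATEGORIES are hypothetical names that are not part of the app, and it assumes the raw output looks like the "<pad> 1. </s>" string shown above.

```python
import re

# Same order as the numbered list in the prompt.
CATEGORIES = [
    "Termination condition",
    "Performance obligation",
    "Payment obligation",
    "Compliance obligation",
    "Support or warranty obligation",
    "Insurance obligation",
    "Quality obligation",
    "Audit obligation",
    "Contract renewal",
    "Exceptions or exemptions",
    "Indemnification",
    "Jurisdiction",
]


def parse_category(raw_output: str) -> str:
    """Map raw generated text such as '<pad> 1. </s>' to a category name."""
    # Drop special tokens like <pad> and </s>.
    cleaned = re.sub(r"<[^>]+>", " ", raw_output).strip()
    # Look for the first number in what is left.
    match = re.search(r"\d+", cleaned)
    if match:
        index = int(match.group())
        if 1 <= index <= len(CATEGORIES):
            return CATEGORIES[index - 1]
    # Anything else (including "None") falls into the catch-all bucket.
    return "None of the above"


print(parse_category("<pad> 1. </s>"))
```

This only papers over the formatting problem; it does nothing for the clauses that are classified wrongly, which is what fine-tuning is for.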
It’s honestly really impressive that we’ve gotten this far without really needing code (well - we needed code to download the model, but up till this point you could really just have used ChatGPT) - but the work awaits! Let’s fine-tune!
I get the zero-shot classifications and re-classify them by hand (for example, in the text I predicted earlier, I would write “Termination condition” as the desired output), and then I put that back in for fine-tuning:
import json
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, T5Tokenizer
zs = load_dataset('csv', data_files='zeroshot.csv', split="train")
# 90% train, 10% test + validation
train_testvalid = zs.train_test_split(test_size=0.1)
# Split the 10% test + valid in half test, half valid
test_valid = train_testvalid['test'].train_test_split(test_size=0.5)
# gather everyone if you want to have a single DatasetDict
zs = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})
model_checkpoint = 'google/flan-t5-large'
tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
max_input_length = 512 # change to setting later
max_target_length = 512 # change to setting later
def preprocess_data(examples):
    inputs = []
    for clause_text in examples['clause']:
        input_prompt = f"""
Identify the clause type. If it is none, reply None.
Text: "{clause_text}"
Contractual clause types:
1. Termination condition
2. Performance obligation
3. Payment obligation
4. Compliance obligation
5. Support or warranty obligation
6. Insurance obligation
7. Quality obligation
8. Audit obligation
9. Contract renewal
10. Exceptions or exemptions
11. Indemnification
12. Jurisdiction
"""
        inputs.append(input_prompt)
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["flan_fewshot_cat"], max_length=max_target_length,
                           truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
tokenized_datasets = zs.map(preprocess_data, batched=True)
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
model_name = "flan-t5-large-skim"
model_dir = f"bigdata/{model_name}"
args = Seq2SeqTrainingArguments(
    model_dir,
    evaluation_strategy="steps",
    eval_steps=100,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=200,
    learning_rate=4e-5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    predict_with_generate=True,
    fp16=False,
    load_best_model_at_end=True,
    # "rouge1" must be a key returned by a compute_metrics function
    # passed to the trainer (not shown here)
    metric_for_best_model="rouge1",
)
data_collator = DataCollatorForSeq2Seq(tokenizer)
# Function that returns the pretrained checkpoint to be fine-tuned
from transformers import T5ForConditionalGeneration
def model_init():
    return T5ForConditionalGeneration.from_pretrained(model_checkpoint)
trainer = Seq2SeqTrainer(
    model=model_init(),
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()
In order to save my work, I save it locally, and also push it to the Hugging Face Hub.
trainer.save_model("bigdata/flan-t5-large-skim") # save model locally
trainer.push_to_hub() # push to hugging face hub
The great thing about pushing it to the Hugging Face Hub is that they have a free inference API that allows you to host the computationally-heavy process on their servers instead of spinning up a server with enough RAM to fit large language models! The only issue is that there are rate limits, and they are not very forthcoming about the explicit limits…
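Since the limits are not documented, one way to live with them is to retry a failed call with exponential backoff. The sketch below is an illustration rather than the app's actual code: query_with_retry is a hypothetical helper, and send stands in for whatever function posts the payload to the inference API, assumed here to raise RuntimeError when a request is rejected.

```python
import time
from typing import Any, Callable


def query_with_retry(
    send: Callable[[dict], Any],
    payload: dict,
    max_retries: int = 5,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Any:
    """Call send(payload), retrying with exponential backoff on RuntimeError."""
    for attempt in range(max_retries):
        try:
            return send(payload)
        except RuntimeError:
            # Give up once the last allowed attempt has failed.
            if attempt == max_retries - 1:
                raise
            # Wait base_delay, then 2x, 4x, ... before trying again.
            sleep(base_delay * 2 ** attempt)
```

The sleep argument is injectable only so the helper can be exercised without actually waiting; in normal use the default time.sleep is fine.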
Text preprocessing
Awesome! Now that we have our language model, surely we are done! Not so fast - a key step in designing apps is to understand how your users will interact with your app. If I were a tired lawyer, I would not want to paste my obligations in one by one to check! That is why I designed my app to handle either long-form text submitted through a text box, or a Word document. Let’s take a look at some of the code, beginning with parsing a Word document:
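from docx import Document  # provided by the python-docx package

# Extract text from the .docx file
# (docx_stream is the uploaded file: a path or a file-like object)
doc = Document(docx_stream)
clause_text = ""
for paragraph in doc.paragraphs:
    clause_text += paragraph.text + "\n"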
Now, because text can be very unstructured (even if very readable!), I need to take a few text preprocessing steps in order to make sure my model isn’t inferring on garbage: empty strings, accidentally captured metadata, page numbers.
And before that, an even more important step: if you have noticed, my model presumes that I am sending in a single, complete sentence. That means that paragraphs will have to be broken up into sentences. This presents many problems, since rule-based methods, like splitting on periods, may end up accidentally catching periods within sentences (for example, decimal places/systems).
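To make the failure mode concrete, here is a small illustration of a rule-based splitter tripping over a decimal point. naive_split is a hypothetical helper written for this example, not something the app uses.

```python
import re


def naive_split(text: str) -> list[str]:
    """Split on every period, which is the rule-based approach described above."""
    return [part.strip() for part in re.split(r"\.", text) if part.strip()]


text = "The fee is 119.07 dollars. Payment is due in 30 days."
pieces = naive_split(text)
print(pieces)
# ['The fee is 119', '07 dollars', 'Payment is due in 30 days']
```

Two sentences come back as three fragments, and the amount is cut in half.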
To solve this problem, I use spaCy’s sentence boundary detection, which in turn uses dependency parsing to propagate from roots (e.g. subject, object) down through their children (e.g. adjectives, adverbs) until the end, which is then taken as a sentence boundary.
import spacy
from spacy.cli import download
# Specify the model name
model_name = 'en_core_web_sm'
# Check if the model is already installed
if model_name not in spacy.util.get_installed_models():
    # If not installed, download the model
    download(model_name)
nlp = spacy.load(model_name)
def split_into_sentences(text):
    sentences = list(nlp(text).sents)
    sentences = [sentence.text for sentence in sentences]
    return sentences
Let’s see how this sentence boundary detection works with a tricky sentence.
>>> text = r"Contractor shall allow public access to all documents, papers, letters or other materials made or received by Contractor in conjunction with this Contract, unless the records are exempt from section 24(a) of Article I of the State Constitution and section 119.07(1), F.S. It is expressly understood that Commerce may unilaterally cancel this Contract for Contractor’s refusal to comply with this provision."
>>> [text.text for text in list(nlp(text).sents)]
['Contractor shall allow public access to all documents, papers, letters or other materials made or received by Contractor in conjunction with this Contract, unless the records are exempt from section 24(a) of Article I of the State Constitution and section 119.07(1), F.S. ', 'It is expressly understood that Commerce may unilaterally cancel this Contract for Contractor’s refusal to comply with this provision.']
Great! Another major thing that I do is to strip out accidentally captured metadata, say if “Page 1 of 24” or “Page 2 of 24” or headers/footers make their way into the Word document. To do this, I compute pairwise cosine similarities between lines and drop every line that is at least 80% similar to another line:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_distances
from itertools import combinations

def remove_repetitive_strings(strings, similarity_threshold=80):
    # Use TF-IDF Vectorizer to embed each line as a vector
    vectorizer = TfidfVectorizer()
    vectors = vectorizer.fit_transform(strings)
    # Calculate pairwise cosine distances
    distances = cosine_distances(vectors)
    # Identify pairs whose similarity is at or above the threshold
    # (i.e. whose cosine distance is at most 1 - threshold)
    similar_pairs = [(strings[i], strings[j]) for i, j in combinations(range(len(strings)), 2)
                     if distances[i, j] <= (1 - similarity_threshold / 100)]
    # Collect every string that appears in a near-duplicate pair
    repetitive_strings = set()
    for string1, string2 in similar_pairs:
        repetitive_strings.add(string1)
        repetitive_strings.add(string2)
    # Keep only the strings that are not near-duplicates of another string
    filtered_strings = [string for string in strings if string not in repetitive_strings]
    return filtered_strings
There are some other basic things that I do, like stripping whitespaces or removing strings that don’t have any alphabetical content.
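Those last two steps might look something like the sketch below. clean_sentences is a hypothetical name and the exact rules in the app may differ; the assumption here is simply that a string is worth keeping if it contains at least one letter.

```python
def clean_sentences(sentences):
    """Strip surrounding whitespace and drop strings with no alphabetical content."""
    cleaned = []
    for sentence in sentences:
        stripped = sentence.strip()
        # Keep only strings that contain at least one letter;
        # this also drops empty strings and bare page numbers.
        if any(ch.isalpha() for ch in stripped):
            cleaned.append(stripped)
    return cleaned


print(clean_sentences(["  Payment is due.  ", "", "24", " - - ", "Section 3"]))
# ['Payment is due.', 'Section 3']
```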
Front-end to back-end: an overall structure
The overall structure of the app is as follows: I run one instance of a ReactJS front-end on Vercel, and another instance of a Flask back-end on Azure. The Flask back-end handles all the text pre-processing and inference, whereas the ReactJS front-end sends the text entered by the user to the back-end through a POST request, and receives the classified text in return.
There are many little details in the React app: hovering over a clause text highlights all other clause texts in the same category, hovering over the category text at the top highlights the clause texts in that category, and the text boxes re-size dynamically. But one feature was an interesting problem that took a while to handle, and that was streaming.
You see, even the best large language models like ChatGPT take a while to complete inference. From the user’s standpoint, the total time of inference might be unacceptable, but they might be more willing to accept the text being processed incrementally, in small chunks.
Therefore, in my use case, instead of waiting for the whole list of texts to be processed, I stream information back from the Flask app at every step of the loop using yield, and React is capable of processing this data while awaiting the final set of data:
# Flask side:
import json
from flask import Flask, Response, request, stream_with_context

app = Flask(__name__)

# splitter() and query() are helpers that are not shown here:
# splitter applies the preprocessing rules, query sends the prompt for inference.

@app.route('/api/generate_clause', methods=['POST'])
def generate_clause():
    def generate():
        progress = 0
        clause_text = request.json['text']
        # lotexts is an orderly list of texts that has already been processed,
        # whether the input was a Word document or just simple text
        lotexts = splitter(clause_text) # splits text according to preprocessing rules
        lodict = {}
        for i, text in enumerate(lotexts):
            progress = (i+1)/len(lotexts)
            input_prompt = f"""
Identify the clause type. If it is none, reply None
Text: "{text}"
Contractual clause types:
1. Termination condition
2. Performance obligation
3. Payment obligation
4. Compliance obligation
5. Support or warranty obligation
6. Insurance obligation
7. Quality obligation
8. Audit obligation
9. Contract renewal
10. Exceptions or exemptions
11. Indemnification
12. Jurisdiction
"""
            # error handling for model initialization
            output = query({
                "inputs": input_prompt, "wait_for_model": True
            })
            category = output[0]['generated_text']
            lodict.update({text: category})
            json_data = {'text_block': clause_text, 'lookup': lodict, 'progress': progress}
            json_string = json.dumps(json_data)
            # A trailing newline separates the streamed JSON objects
            yield json_string + "\n"
    # Use Response and stream_with_context to handle streaming
    return Response(stream_with_context(generate()),
                    content_type='text/event-stream')
// React side
const [outputText, setOutputText] = useState(''); //text just gives the raw block text to be displayed
const [outputData, setOutputData] = useState({})
const [progress, setProgress] = useState(0);
const handleGenerateClick = async () => {
  const url = 'http://localhost:5000/api/generate_clause';
  var tmpPromptResponse = '';
  try {
    const response = await fetch(url , {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json'
      },
      body: JSON.stringify({
        text: inputText,
      }),
    });
    // eslint-disable-next-line no-undef
    let decoder = new TextDecoderStream();
    if (!response.body) return;
    const reader = response.body
      .pipeThrough(decoder)
      .getReader();
    while (true) {
      var {value, done} = await reader.read();
      if (done) {
        break;
      } else {
        tmpPromptResponse = JSON.parse(value);
        setOutputText(tmpPromptResponse.text_block);
        setOutputData(tmpPromptResponse.lookup);
        setProgress(tmpPromptResponse.progress);
      }
    }
  } catch (error) {
    console.log(error);
  }
};
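One caveat with reading a stream this way is that a network read does not have to line up with a single yield on the server: two JSON objects can arrive in one chunk, or one object can be split across two. A more defensive reader buffers the text and only parses complete lines. The sketch below shows that buffering logic in Python for consistency with the back-end code; iter_json_objects is a hypothetical helper, and it assumes the server ends every JSON object with a newline, which the code above may or may not do. The same idea would carry over to the React reader loop.

```python
import json


def iter_json_objects(chunks):
    """Yield parsed objects from an iterable of newline-delimited JSON text chunks."""
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        # Everything before the last newline is a complete line;
        # whatever follows it stays in the buffer for the next chunk.
        *lines, buffer = buffer.split("\n")
        for line in lines:
            if line.strip():
                yield json.loads(line)
    # Handle a final object that arrived without a trailing newline.
    if buffer.strip():
        yield json.loads(buffer)


# Two objects, with the second one split across two network reads.
chunks = ['{"progress": 0.5}\n{"prog', 'ress": 1.0}\n']
print(list(iter_json_objects(chunks)))
# [{'progress': 0.5}, {'progress': 1.0}]
```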
A final word
That brings me to the end of this blog! I think it was a pretty good experience handling everything from conceptualization, down to specifying a language model, testing and finetuning it, before setting it up on a Flask back-end server, and then designing a very preliminary user experience on a React front-end, and finally considering how all the moving pieces would fit together within a production context with compute power constraints (well, in my case, severe constraints).
A lot of what I could do on my GPU-enabled development run was rendered moot because of production constraints: for example, without using Hugging Face’s inference API, I could categorize documents a lot faster and without any limits. But we all must accept and work within the constraints provided, and this project was very instructive in the details of the full production process of an AI project.