The Code: Chain-of-Thought Prompting in Practice
In this section, we’ll implement chain-of-thought prompting. Since we don’t have access to PaLM (540B), we’ll demonstrate the structure using a smaller open-source model. The key is understanding the prompt structure, not the model size.
Note: This code runs on Google Colab with free resources. For production use with complex reasoning, you’d want a larger model (like GPT-4 or PaLM), but the prompt structure remains identical.
Code 1: Basic Chain-of-Thought Prompt Structure
# Chain-of-Thought Prompting Example
# Runs on Google Colab with transformers library
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load a small open-source model
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Define few-shot examples with chain-of-thought
cot_prompt = """Q: Sarah has 10 apples. She buys 5 more.
Then she gives 3 to her friend. How many apples does she have?
A: Let me work through this step by step.
Starting apples: 10
After buying: 10 + 5 = 15
After giving away: 15 - 3 = 12
Answer: 12 apples
Q: A baker makes 24 cookies and puts them in boxes of 4.
He sells 3 boxes. How many cookies remain?
A: Let me break this down.
Total cookies: 24
Cookies per box: 4
Number of boxes: 24 / 4 = 6 boxes
Boxes sold: 3
Boxes remaining: 6 - 3 = 3 boxes
Cookies remaining: 3 × 4 = 12 cookies
Answer: 12 cookies
Q: A train travels 60 km in 2 hours. How far does
it travel in 5 hours at the same speed?
A: Step by step:
Distance in 2 hours: 60 km
Speed: 60 / 2 = 30 km/hour
Distance in 5 hours: 30 × 5 = 150 km
Answer: 150 km
Q: A store has 100 items. They receive 50 more items.
They sell 30 items. How many items remain?
A:"""
# Tokenize and generate
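input_ids = tokenizer.encode(cot_prompt, return_tensors="pt")
# The prompt alone is longer than 150 tokens, so limit the new tokens, not the total length
output = model.generate(input_ids, max_new_tokens=80,
                        temperature=0.7, do_sample=True,
                        pad_token_id=tokenizer.eos_token_id)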
# Decode and print
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
What this code does:
- Loads GPT-2 (small, free model)
- Creates a few-shot prompt with three examples, each showing:
- Problem statement (Q)
- Explicit reasoning steps (A with breakdown)
- Generates a response for a new problem
- The model picks up the “problem → reasoning → answer” pattern in context from the examples; no weights are updated
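The few-shot prompt above is written out by hand. As a minimal sketch, the same structure can be assembled from a list of exemplars. The helper name `build_cot_prompt` and the tuple layout are assumptions made for illustration, not part of the original listing.

```python
from typing import List, Tuple


def build_cot_prompt(examples: List[Tuple[str, str, str]], question: str) -> str:
    """Format (question, reasoning, answer) exemplars, then append the new question.

    Hypothetical helper for illustration only.
    """
    blocks = []
    for q, reasoning, answer in examples:
        blocks.append(f"Q: {q}\nA: {reasoning}\nAnswer: {answer}")
    # Leave the final "A:" open so the model writes its own reasoning
    blocks.append(f"Q: {question}\nA:")
    return "\n".join(blocks)


examples = [
    (
        "Sarah has 10 apples. She buys 5 more. Then she gives 3 to her friend. "
        "How many apples does she have?",
        "Let me work through this step by step.\n"
        "Starting apples: 10\n"
        "After buying: 10 + 5 = 15\n"
        "After giving away: 15 - 3 = 12",
        "12 apples",
    ),
]
prompt = build_cot_prompt(
    examples,
    "A store has 100 items. They receive 50 more items. "
    "They sell 30 items. How many items remain?",
)
print(prompt)
```

Keeping the exemplars in a list makes it easy to add, remove, or reorder them and see how the output changes.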
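Expected output: a model that follows the pattern would continue along these lines. GPT-2 usually copies the format but often gets the arithmetic wrong or drifts off topic, and because sampling is enabled the text changes from run to run: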
A: Let me think through this:
Starting items: 100
Items received: 50
Total after receiving: 100 + 50 = 150
Items sold: 30
Items remaining: 150 - 30 = 120
Answer: 120 items
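To use a completion like this programmatically, the final answer has to be pulled out of the text. The sketch below assumes the completion contains an "Answer:" line followed by a number, as in the exemplars. The function name `extract_final_answer` is chosen for illustration. It should be applied to the newly generated text only, since the prompt itself contains several "Answer:" lines.

```python
import re
from typing import Optional


def extract_final_answer(text: str) -> Optional[str]:
    """Return the number after the first 'Answer:' in a completion, or None."""
    # Use the first match: models often keep going and invent further Q/A pairs
    match = re.search(r"Answer:\s*(-?\d+(?:\.\d+)?)", text)
    return match.group(1) if match else None


completion = """ Let me think through this:
Starting items: 100
Items received: 50
Total after receiving: 100 + 50 = 150
Items sold: 30
Items remaining: 150 - 30 = 120
Answer: 120 items"""

print(extract_final_answer(completion))  # 120
```

Returning None when no answer is found lets the caller tell "no parsable answer" apart from a wrong answer.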
Code 2: Standard vs. Chain-of-Thought Comparison
# Compare Standard Prompting vs Chain-of-Thought
# This demonstrates the difference in prompt structure
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# STANDARD PROMPTING (just Q and A, no reasoning)
standard_prompt = """Q: Sarah has 10 apples. She buys 5 more.
Then she gives 3 to her friend. How many?
A: 12
Q: A baker makes 24 cookies, puts them in boxes of 4.
He sells 3 boxes. How many cookies remain?
A: 12
Q: A train travels 60 km in 2 hours.
How far in 5 hours?
A: 150
Q: A store has 100 items, receives 50 more, sells 30.
How many remain?
A:"""
# CHAIN-OF-THOUGHT PROMPTING (Q, reasoning, A)
cot_prompt_alt = """Q: Sarah has 10 apples. She buys 5 more.
Then she gives 3 to her friend. How many?
A: Starting: 10. After buying: 10 + 5 = 15.
After giving: 15 - 3 = 12. Answer: 12
Q: A baker makes 24 cookies, puts them in boxes of 4.
He sells 3 boxes. How many cookies remain?
A: Total cookies: 24. Per box: 4. Number of boxes: 24/4 = 6.
Sold: 3. Remaining boxes: 6 - 3 = 3.
Remaining cookies: 3 × 4 = 12. Answer: 12
Q: A train travels 60 km in 2 hours. How far in 5 hours?
A: Speed: 60/2 = 30 km/hour. Distance: 30 × 5 = 150 km.
Answer: 150
Q: A store has 100 items, receives 50 more, sells 30.
How many remain?
A:"""
print("=" * 50)
print("STANDARD PROMPTING OUTPUT:")
print("=" * 50)
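input_ids = tokenizer.encode(standard_prompt, return_tensors="pt")
# Greedy decoding (temperature has no effect unless do_sample=True); cap new tokens only
output = model.generate(input_ids, max_new_tokens=20, do_sample=False,
                        pad_token_id=tokenizer.eos_token_id)
# Print only the newly generated tokens, not the prompt
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))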
print("\n" + "=" * 50)
print("CHAIN-OF-THOUGHT PROMPTING OUTPUT:")
print("=" * 50)
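input_ids = tokenizer.encode(cot_prompt_alt, return_tensors="pt")
# Same greedy settings, with more room for the reasoning steps
output = model.generate(input_ids, max_new_tokens=60, do_sample=False,
                        pad_token_id=tokenizer.eos_token_id)
# Print only the newly generated tokens, not the prompt
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True))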
What this demonstrates:
- Standard prompting gives Q→A directly
- CoT prompting includes reasoning in the answer
- With larger models (100B+), CoT dramatically improves accuracy
- GPT-2 is too small to show the difference clearly, but the prompt follows the same question, reasoning, answer format the paper uses with PaLM
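Comparing two printed outputs by eye says little about accuracy. One way to quantify the difference is to score each prompt style over a set of labeled problems. The sketch below is an assumption-laden illustration: `accuracy`, `last_number`, and `stub_generate` are hypothetical names, and the stub stands in for a real model call that returns only the generated continuation.

```python
import re
from typing import Callable, List, Tuple


def last_number(completion: str) -> str:
    """Return the last number in the first answer block, or "" if there is none."""
    first_block = completion.split("\nQ:")[0]  # ignore extra questions the model invents
    numbers = re.findall(r"-?\d+(?:\.\d+)?", first_block)
    return numbers[-1] if numbers else ""


def accuracy(generate: Callable[[str], str], few_shot: str,
             problems: List[Tuple[str, str]]) -> float:
    """Fraction of problems whose predicted number matches the gold answer."""
    correct = 0
    for question, gold in problems:
        completion = generate(f"{few_shot}\nQ: {question}\nA:")
        correct += int(last_number(completion) == gold)
    return correct / len(problems)


# Hypothetical stand-in for a model call: always returns the same chain
def stub_generate(prompt: str) -> str:
    return (" Start: 100. After receiving: 100 + 50 = 150. "
            "After selling: 150 - 30 = 120. Answer: 120\nQ: A farmer has 7 cows.")


problems = [("A store has 100 items, receives 50 more, sells 30. How many remain?", "120")]
print(accuracy(stub_generate, "", problems))  # 1.0 for this stub
```

Running the same function once with the standard few-shot prefix and once with the chain-of-thought prefix, on the same problems, gives two numbers that can be compared directly.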
Key Takeaways from the Code
- Prompt structure is everything: The model learns the pattern from few-shot examples
- Reasoning must be explicit: Include calculation steps, not just final answers
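- Scale is critical: In the paper, the gains from chain-of-thought prompting appear mainly in models of roughly 100B parameters and up, so the technique and the model size work together
- Small models benefit little: The paper reports that CoT does not help small models and can even hurt them, while PaLM 540B with CoT exemplars surpassed a fine-tuned GPT-3 with a verifier on GSM8K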
Why Run This on Colab?
The gpt2 checkpoint used here has about 124M parameters (the largest variant, gpt2-xl, has 1.5B) and runs comfortably on a free Colab instance. It is too small to solve the math reliably, but it shows:
- How to structure CoT prompts
- How the model processes examples
- How to extract generated text
For real reasoning tasks, you’d use:
- Claude (Anthropic)
- GPT-4 (OpenAI)
- Llama 2 Chat (Meta)
- PaLM 2 (Google), the successor to the PaLM model used in this paper
All of these support the same CoT prompt structure shown above.
Pro Tip: Self-Consistency (Follow-Up)
Several of the paper’s authors later contributed to a follow-up called “Self-Consistency Improves Chain of Thought Reasoning in Language Models” (Wang et al., 2022). The idea:
Instead of generating one reasoning chain, generate multiple different chains and take a majority vote:
# Pseudo-code for self-consistency
answers = []
for i in range(5):  # Generate 5 reasoning chains
    chain = model.generate(prompt, temperature=1.0)  # High temp = diversity
    answer = extract_final_answer(chain)
    answers.append(answer)
final_answer = majority_vote(answers)  # Most common answer
This pushes accuracy even higher: the self-consistency paper reports GSM8K accuracy rising from 56.5% to 74.4% with PaLM 540B.
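The answer extraction and voting steps of the pseudo-code can be sketched in plain Python. This is a minimal illustration under stated assumptions: the chains are hard-coded hypothetical samples rather than real model output, and `extract_final_answer` and `majority_vote` are illustrative implementations, not the authors' code.

```python
import re
from collections import Counter
from typing import List, Optional


def extract_final_answer(chain: str) -> Optional[str]:
    """Return the number after 'Answer:' in a reasoning chain, or None."""
    match = re.search(r"Answer:\s*(-?\d+(?:\.\d+)?)", chain)
    return match.group(1) if match else None


def majority_vote(answers: List[Optional[str]]) -> Optional[str]:
    """Return the most common parsed answer, ignoring chains with no answer."""
    valid = [a for a in answers if a is not None]
    if not valid:
        return None
    return Counter(valid).most_common(1)[0][0]


# Hypothetical sampled chains, hard-coded in place of real model output
chains = [
    "100 + 50 = 150. 150 - 30 = 120. Answer: 120",
    "100 - 30 = 70. 70 + 50 = 120. Answer: 120",
    "100 + 50 = 150. 150 - 30 = 130. Answer: 130",  # arithmetic slip
    "The store has many items.",                    # no parsable answer
    "50 - 30 = 20. 100 + 20 = 120. Answer: 120",
]
answers = [extract_final_answer(c) for c in chains]
final_answer = majority_vote(answers)
print(final_answer)  # 120
```

The vote is over final answers, not over reasoning text, so chains that reach 120 by different routes still count together.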