Intro to IRT¶
In this notebook, we will give you an intro to IRT. We will cover:
- The data format py-irt uses
- Using off the shelf models from py-irt
- Implementing your own model
In [ ]:
import sys
import numpy as np
from pydantic import BaseModel
# Add py-irt if using git cloned version
# sys.path.append("/data/home/par/code/py-irt")
Assumptions¶
- Subjects
- Subjects will have a randomly drawn skill
- 10 Subjects
- Items
- Each item has a difficulty that we will randomly draw
- Each item will be valid at
validity_rate
rate, and invalid otherwise - 1,000 Items
- Responses (correctness of subject responses to items)
- If the item is valid: Determine randomly according to IRT equation
- If invalid: flip a fair coin
In [ ]:
class Item(BaseModel):
valid: bool
difficulty: float
item_id: str
category: str = 'all'
class Subject(BaseModel):
subject_id: str
skill: float
min_diff = -4
max_diff = 4
validity_rate = .9
def create_item(item_id: str, category: str):
validity = np.random.uniform()
if validity > validity_rate:
valid = 0
else:
valid = 1
return Item(item_id=item_id, valid=valid, difficulty=np.random.uniform(low=min_diff, high=max_diff), category=category)
max_skill = 4
min_skill = -4
def create_subject(subject_id: str):
return Subject(subject_id=subject_id, skill=np.random.uniform(low=min_skill, high=max_skill))
In [ ]:
items = [create_item(f'item_{idx}', 'all') for idx in range(1_000)]
subjects = [create_subject(f'subject_{idx}') for idx in range(10)]
Exporting to py-irt Jsonlines format¶
So that we can use the built-in py-irt data loader, we will write this dataset to jsonlines. The format is:
- Valid Line:
"subject_id": "<subject_id>", "responses": {"<item_id>": <response>}}
- Responses are 1 or 0.
- Note, responses is a dictionary and it does not need to be complete.
In [ ]:
# We'll do a quick export to json to make reading the dataset easier. From the docs:
# Each row looks like this:
# {"subject_id": "<subject_id>", "responses": {"<item_id>": <response>}}
import random
from py_irt.io import write_jsonlines
from py_irt.dataset import Dataset
def write_irt_dataset(subjects: list[Subject], items: list[Item], path: str):
print("Writing dataset with")
print("N Subjects", len(subjects))
print("N Items", len(items))
rows = []
score_by_subject = {}
lookup = {}
for subject in subjects:
responses = {}
correct = 0
total = 0
for item in items:
if item.valid:
responses[item.item_id] = int(1 / (1 + np.exp(-(subject.skill - item.difficulty))) > random.random())
else:
responses[item.item_id] = int(random.random() > .5)
correct += responses[item.item_id]
total += 1
score_by_subject[subject.subject_id] = correct / total
lookup[subject.subject_id] = responses
rows.append({"subject_id": subject.subject_id, "responses": responses})
write_jsonlines(path, rows)
return score_by_subject, lookup
score_by_subject, subject_responses = write_irt_dataset(subjects, items, '/tmp/irt_dataset_section_2.jsonlines')
Writing dataset with N Subjects 10 N Items 1000
Reading data in py-irt and Training a model¶
Next, we will:
- Read this dataset in py-irt
- Create a basic config
- Train a simple 1PL model
Afterwards, we will:
- Dig into the py-irt model definition and adapt it to work for 2PL
In [ ]:
from py_irt.dataset import Dataset
dataset = Dataset.from_jsonlines("/tmp/irt_dataset_section_2.jsonlines")
[18:04:58] amortized: False dataset.py:112
In [ ]:
from py_irt.config import IrtConfig
from py_irt.training import IrtModelTrainer
config = IrtConfig(model_type='1pl', log_every=500, dropout=.2)
trainer = IrtModelTrainer(config=config, data_path=None, dataset=dataset)
trainer.train(epochs=5000, device='cuda')
[16:43:19] amortized: False dataset.py:112
[16:43:19] Vocab size: None training.py:90
args: {'device': 'cuda', 'num_items': 1000, 'num_subjects': 10} training.py:134
Parsed Model Args: {'device': 'cuda', 'num_items': 1000, 'num_subjects': 10, 'priors': training.py:147 'vague', 'dropout': 0.2, 'hidden': 100, 'vocab_size': None}
Output()
Training Pyro IRT Model for 5000 epochs
torch.Size([10000]) torch.Size([10000])
In [ ]:
# Next we can compare: Actual skill, inferred skill, and accuracy
for subject, skill, acc in sorted(list(zip(subjects, trainer.last_params['ability'], score_by_subject.values())), key=lambda v: v[0].skill):
print(subject.subject_id, "Real Skill", subject.skill, "Inferred Skill", skill, "Acc", acc)
subject_8 Real Skill -3.884467092477161 Inferred Skill -3.346115827560425 Acc 0.127 subject_3 Real Skill -2.4092903253582962 Inferred Skill -2.1541693210601807 Acc 0.226 subject_2 Real Skill -2.1774244939456846 Inferred Skill -1.9487590789794922 Acc 0.26 subject_5 Real Skill -2.1411686953088873 Inferred Skill -1.7273616790771484 Acc 0.288 subject_0 Real Skill -1.2458803146590771 Inferred Skill -0.6864756345748901 Acc 0.387 subject_9 Real Skill 0.08445832670945386 Inferred Skill 0.10757223516702652 Acc 0.506 subject_4 Real Skill 0.6030642924177165 Inferred Skill 0.5390836000442505 Acc 0.556 subject_1 Real Skill 2.0165246707222027 Inferred Skill 1.9551409482955933 Acc 0.708 subject_6 Real Skill 3.2611789714609936 Inferred Skill 3.055710554122925 Acc 0.816 subject_7 Real Skill 3.3842764622290247 Inferred Skill 3.1915154457092285 Acc 0.822
In [ ]:
for item, diff in sorted(list(zip(items, trainer.last_params["diff"]))[:20], key=lambda v: v[0].difficulty):
print(item.difficulty, diff)
-3.990418584245244 -3.9722235202789307 -2.6011726771435217 -0.31765908002853394 -2.0620805079999878 -1.6781026124954224 -2.0136971311869223 -3.920189380645752 -1.3044214045970879 -0.9731835126876831 -1.2403805966596124 -1.7017406225204468 -0.4605605330027691 -0.06663890182971954 -0.22063079854633472 -0.09359151124954224 0.1587640996540305 0.40942835807800293 0.5606946449544852 0.2367488145828247 1.0069018644358296 -1.8538750410079956 1.1559239117148525 -1.6636019945144653 1.1776646088154727 1.5613257884979248 1.6586750877950083 2.6625564098358154 2.2038272056877064 1.1499176025390625 2.265316853736441 -0.1889083981513977 2.872176916289317 3.8063602447509766 3.86651072108923 3.609407663345337 3.895528635928546 2.5975146293640137 3.9071289046765534 32.72386169433594
Defining and modifying our own model¶
In the previous example, we used a off-the-shelf model, but it is not too difficult to modify an existing model. In general, model definitions will have:
- A model: This represents our generative story, i.e P(Z|X)
- A guide: This represents the approximating distribution, Q(Z)
- Helper functions (e.g., fit, export, etc)
If we look at the non-commented code, we will see:
observe_data
: represents the IRT equation- There are two levels of sampling statements:
- The draw of IRT parameters from a distribution
- The draw of parameters of that distribution
So to add a new variable, for example discriminability, we will need to:
- In the model, add sampling statements for the priors
- In the model, add sampling statement for discrimability
- Incorporate the variable in the
observe_data
statement - In the guide, make parallel changes.
As for what distribution to use, we'll make a slight change and keep discriminability non-negative by instead of drawing disc~Normal
, we'll have log disc ~ Normal
which is equivalent to disc ~ LogNormal
.
In [ ]:
# Copied from tutorial_section_2_start.py
from py_irt.models import abstract_model
import pyro
import pyro.distributions as dist
import torch
import torch.distributions.constraints as constraints
from pyro.infer import SVI, Trace_ELBO, EmpiricalMarginal
from pyro.optim import Adam
import pandas as pd
from functools import partial
import numpy as np
class TutorialSection2Model(abstract_model.IrtModel):
"""1PL IRT model"""
def __init__(
self,
*,
priors: str,
num_items: int,
num_subjects: int,
verbose: bool = False,
device: str = "cpu",
vocab_size: int = None,
dropout: float = None,
hidden: int = None,
**kwargs,
):
super().__init__(
device=device,
num_items=num_items,
num_subjects=num_subjects,
verbose=verbose,
)
def model_hierarchical(self, models, items, obs):
mu_b = pyro.sample(
"mu_b",
dist.Normal(
torch.tensor(0.0, device=self.device),
torch.tensor(1.0e6, device=self.device),
),
)
u_b = pyro.sample(
"u_b",
dist.Gamma(
torch.tensor(1.0, device=self.device),
torch.tensor(1.0, device=self.device),
),
)
mu_theta = pyro.sample(
"mu_theta",
dist.Normal(
torch.tensor(0.0, device=self.device),
torch.tensor(1.0e6, device=self.device),
),
)
u_theta = pyro.sample(
"u_theta",
dist.Gamma(
torch.tensor(1.0, device=self.device),
torch.tensor(1.0, device=self.device),
),
)
######################################################################
# START: NEW CODE
mu_gamma = pyro.sample(
"mu_gamma",
dist.Normal(
torch.tensor(0.0, device=self.device),
torch.tensor(1.0e6, device=self.device)
)
)
u_gamma = pyro.sample(
"u_gamma",
dist.Gamma(
torch.tensor(1.0, device=self.device),
torch.tensor(1.0, device=self.device),
)
)
with pyro.plate("gammas", self.num_items, device=self.device):
disc = pyro.sample("gamma", dist.LogNormal(mu_gamma, 1 / u_gamma))
# END: NEW CODE
######################################################################
with pyro.plate("thetas", self.num_subjects, device=self.device):
ability = pyro.sample("theta", dist.Normal(mu_theta, 1.0 / u_theta))
with pyro.plate("bs", self.num_items, device=self.device):
diff = pyro.sample("b", dist.Normal(mu_b, 1.0 / u_b))
with pyro.plate("observe_data", obs.size(0)):
pyro.sample(
"obs",
################# New Sampling statement ##################
dist.Bernoulli(logits=disc[items] * (ability[models] - diff[items])), obs=obs
################## Original statement ####################
#dist.Bernoulli(logits=ability[models] - diff[items]), obs=obs
)
def guide_hierarchical(self, models, items, obs):
loc_mu_b_param = pyro.param("loc_mu_b", torch.tensor(0.0, device=self.device))
scale_mu_b_param = pyro.param(
"scale_mu_b",
torch.tensor(1.0e2, device=self.device),
constraint=constraints.positive,
)
loc_mu_theta_param = pyro.param(
"loc_mu_theta", torch.tensor(0.0, device=self.device)
)
scale_mu_theta_param = pyro.param(
"scale_mu_theta",
torch.tensor(1.0e2, device=self.device),
constraint=constraints.positive,
)
alpha_b_param = pyro.param(
"alpha_b",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
beta_b_param = pyro.param(
"beta_b",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
alpha_theta_param = pyro.param(
"alpha_theta",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
beta_theta_param = pyro.param(
"beta_theta",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
m_theta_param = pyro.param(
"loc_ability", torch.zeros(self.num_subjects, device=self.device)
)
s_theta_param = pyro.param(
"scale_ability",
torch.ones(self.num_subjects, device=self.device),
constraint=constraints.positive,
)
m_b_param = pyro.param(
"loc_diff", torch.zeros(self.num_items, device=self.device)
)
s_b_param = pyro.param(
"scale_diff",
torch.ones(self.num_items, device=self.device),
constraint=constraints.positive,
)
######### New parameters for disc ################
loc_mu_gamma_param = pyro.param(
"loc_mu_gamma", torch.tensor(0.0, device=self.device)
)
scale_mu_gamma_param = pyro.param(
"scale_mu_gamma",
torch.tensor(1.0e2, device=self.device),
constraint=constraints.positive,
)
alpha_gamma_param = pyro.param(
"alpha_gamma",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
beta_gamma_param = pyro.param(
"beta_gamma",
torch.tensor(1.0, device=self.device),
constraint=constraints.positive,
)
m_gamma_param = pyro.param("loc_disc", torch.zeros(self.num_items, device=self.device))
s_gamma_param = pyro.param("scale_disc", torch.ones(self.num_items, device=self.device), constraint=constraints.positive)
# sample statements
pyro.sample("mu_b", dist.Normal(loc_mu_b_param, scale_mu_b_param))
pyro.sample("u_b", dist.Gamma(alpha_b_param, beta_b_param))
pyro.sample("mu_theta", dist.Normal(loc_mu_theta_param, scale_mu_theta_param))
pyro.sample("u_theta", dist.Gamma(alpha_theta_param, beta_theta_param))
# New Sampling
# Prior sampling
pyro.sample("mu_gamma", dist.Normal(loc_mu_gamma_param, scale_mu_gamma_param))
pyro.sample("u_gamma", dist.Gamma(alpha_gamma_param, beta_gamma_param))
# Paramter sampling
with pyro.plate("gammas", self.num_items, device=self.device):
pyro.sample("gamma", dist.LogNormal(m_gamma_param, s_gamma_param))
with pyro.plate("thetas", self.num_subjects, device=self.device):
pyro.sample("theta", dist.Normal(m_theta_param, s_theta_param))
with pyro.plate("bs", self.num_items, device=self.device):
pyro.sample("b", dist.Normal(m_b_param, s_b_param))
def get_model(self):
return self.model_hierarchical
def get_guide(self):
return self.guide_hierarchical
def fit(self, models, items, responses, num_epochs):
"""Fit the IRT model with variational inference"""
optim = Adam({"lr": 0.1})
svi = SVI(
self.model_hierarchical,
self.guide_hierarchical,
optim,
loss=Trace_ELBO(),
)
pyro.clear_param_store()
for j in range(num_epochs):
loss = svi.step(models, items, responses)
if j % 100 == 0 and self.verbose:
print("[epoch %04d] loss: %.4f" % (j + 1, loss))
print("[epoch %04d] loss: %.4f" % (j + 1, loss))
values = ["loc_diff", "scale_diff", "loc_ability", "scale_ability"]
def export(self):
return {
"ability": pyro.param("loc_ability").data.tolist(),
"diff": pyro.param("loc_diff").data.tolist(),
"disc": pyro.param("loc_disc").data.tolist(),
}
def predict(self, subjects, items, params_from_file=None):
"""predict p(correct | params) for a specified list of model, item pairs"""
if params_from_file is not None:
model_params = params_from_file
else:
model_params = self.export()
abilities = np.array([model_params["ability"][i] for i in subjects])
diffs = np.array([model_params["diff"][i] for i in items])
discs = np.array([model_params["disc"][i] for i in items])
return 1 / (1 + np.exp(-discs * (abilities - diffs)))
def summary(self, traces, sites):
"""Aggregate marginals for MCM"""
marginal = (
EmpiricalMarginal(traces, sites)
._get_samples_and_weights()[0]
.detach()
.cpu()
.numpy()
)
print(marginal)
site_stats = {}
for i in range(marginal.shape[1]):
site_name = sites[i]
marginal_site = pd.DataFrame(marginal[:, i]).transpose()
describe = partial(
pd.Series.describe, percentiles=[0.05, 0.25, 0.5, 0.75, 0.95]
)
site_stats[site_name] = marginal_site.apply(describe, axis=1)[
["mean", "std", "5%", "25%", "50%", "75%", "95%"]
]
return site_stats
In [ ]:
from py_irt.config import IrtConfig
from py_irt.training import IrtModelTrainer
config = IrtConfig(model_type=TutorialSection2Model, log_every=500, dropout=.2)
trainer = IrtModelTrainer(config=config, data_path=None, dataset=dataset)
trainer.train(epochs=2000, device='cuda')
[18:15:45] Vocab size: None training.py:93
args: {'device': 'cuda', 'num_items': 1000, 'num_subjects': 10} training.py:143
Parsed Model Args: {'device': 'cuda', 'num_items': 1000, 'num_subjects': 10, 'priors': training.py:156 'vague', 'dropout': 0.2, 'hidden': 100, 'vocab_size': None}
Output()
Training Pyro IRT Model for 2000 epochs
torch.Size([10000]) torch.Size([10000])
In [ ]:
import pyro
pyro.render_model(trainer._pyro_model, model_args=(
torch.zeros(1000, dtype=torch.long, device='cuda'),
torch.zeros(1000, dtype=torch.long, device='cuda'),
torch.zeros(1000, dtype=torch.float, device='cuda'),
))
Out[ ]:
In [ ]:
# Next we can compare: Actual skill, inferred skill, and accuracy
for subject, skill, acc in sorted(list(zip(subjects, trainer.last_params['ability'], score_by_subject.values())), key=lambda v: v[0].skill):
print(subject.subject_id, "Real Skill", subject.skill, "Inferred Skill", skill, "Acc", acc)
subject_8 Real Skill -2.0992039470170196 Inferred Skill -4.6167683601379395 Acc 0.299 subject_0 Real Skill -1.1475070994498333 Inferred Skill -0.4494936466217041 Acc 0.381 subject_9 Real Skill -0.820635677327914 Inferred Skill 0.22900916635990143 Acc 0.417 subject_3 Real Skill 0.1286691849779622 Inferred Skill 1.454652190208435 Acc 0.517 subject_5 Real Skill 0.4896205821266699 Inferred Skill 1.7657090425491333 Acc 0.543 subject_2 Real Skill 0.6767643123720228 Inferred Skill 1.977298617362976 Acc 0.582 subject_6 Real Skill 0.7634125176672466 Inferred Skill 1.9012914896011353 Acc 0.573 subject_4 Real Skill 1.768818348571382 Inferred Skill 4.846320629119873 Acc 0.675 subject_1 Real Skill 2.925906218662666 Inferred Skill 7.383442401885986 Acc 0.79 subject_7 Real Skill 3.6667097508179456 Inferred Skill 12.50934886932373 Acc 0.86
In [ ]:
for item, diff, disc in sorted(list(zip(items, trainer.last_params["diff"], trainer.last_params['disc']))[:20], key=lambda v: v[0].difficulty):
print(item.difficulty, diff, disc)
-3.8217550888270777 -3.5826640129089355 2.1924033164978027 -3.69833877591533 -9.665823936462402 0.6793660521507263 -3.4538712920910974 -9.909111022949219 -0.14002259075641632 -3.4031693619313703 -9.54269027709961 0.08277589827775955 -3.3821151826150695 -3.430758476257324 -1.6421877145767212 -3.3417983185345825 -10.073620796203613 0.8655638694763184 -3.2896540965974825 -9.245588302612305 1.2930814027786255 -3.1218402625983757 -4.012744426727295 0.9936566948890686 -2.0961132414719534 -6.080544948577881 -1.4029154777526855 -0.5853620860445243 0.8478484749794006 0.3265891969203949 0.1524570163147061 1.7222025394439697 -0.3942753076553345 0.36958442550909076 0.6623020768165588 0.25205308198928833 0.5101799299201764 4.00017786026001 -2.9407012462615967 0.9451040264956427 8.412286758422852 -2.2563652992248535 1.1586789766540804 1.7541354894638062 -0.1581266075372696 1.7491729978605326 7.629848957061768 -1.6286585330963135 1.8610770213746637 7.323568820953369 -0.9194203019142151 1.9117725888958557 7.841511249542236 -1.87253999710083 2.0980146884621664 -1.5621899366378784 -7.525204658508301 2.9071880117471194 8.110783576965332 0.5085978507995605
In [ ]: