% lab 8: Bayes Rule and Bayesian updates to probability
clear
close all



%% 1 - conditional and marginal probability (prosecutor fallacy)
% see slides.
sensitivity = 0.9;              % probability of positive test given disease
specificity = 0.8;              % probability of negative test given no disease
prevalence = 0.02;              % probability of disease in the tested population
n = 100000;                     % number of test subjects

%% Compute the full table of expected numbers of subjects, both 
% marginals (e.g., proportion and number who test negative) and joint 
% (e.g., proportion and number of subjects who test negative and 
% don't have Covid)
%
% Layout of the table m (proportions) and mm (expected counts):
%   rows:    1 = test negative, 2 = test positive, 3 = column marginals
%   columns: 1 = disease,       2 = no disease,    3 = row marginals

m = [prevalence*(1-sensitivity) (1-prevalence)*specificity ; prevalence*sensitivity (1-prevalence)*(1-specificity) ; prevalence (1-prevalence)];
m(:,3) = sum(m,2);              % append the row marginals; m(3,3) is 1
disp(m);
mm = n*m;                       % expected counts for n subjects
disp(mm)                        % mm is already scaled by n; do not scale again

%% Compute P(infected | test positive) and p(not infected | test negative)
% two different ways (printed as percentages):

% First method: read the answer straight off the table of expected counts.
% Row 2 holds the positive tests, row 1 the negative tests; column 3 is the
% row total, so each ratio is "joint count / marginal count".
pctInfectedGivenPos = 100*mm(2,1)/mm(2,3);
disp(pctInfectedGivenPos);
pctHealthyGivenNeg = 100*mm(1,2)/mm(1,3);
disp(pctHealthyGivenNeg);

% Second method: Bayes Rule, likelihood * prior / evidence, where the
% evidence (marginal probability of the test outcome) is taken from the
% third column of the proportion table.

bayesInfectedGivenPos = 100*sensitivity*prevalence/m(2,3);
disp(bayesInfectedGivenPos);
bayesHealthyGivenNeg = 100*specificity*(1-prevalence)/m(1,3);
disp(bayesHealthyGivenNeg);

%% 2 - binomial
% Suppose we collected N coin flips of which proportion p came up head
% (or had the letter 'e', or whatever)
N = 10;
p = 4/10;
% If that proportion is ground truth, what is the mean, SD and variance of
% the count of heads in N coin flips?

% Method 1: Use makedist(), mean(), std(), var()
Bino = makedist('Binomial','n',N,'p',p);
fprintf(1,'mean=%.2f, SD = %.2f var = %.2f\n',mean(Bino),std(Bino),var(Bino));

% Method 2: compute the full binomial distribution and derive the 
% statistics from that. You may use nchoosek()
values = 0:N;           % the possible numbers of heads
% Preallocate so that re-running with a smaller N cannot leave stale entries
probs = zeros(1,N+1);
% Loop because nchoosek() won't take a vector argument
for i = 1:(N+1)
    probs(i) = nchoosek(N,values(i))*p^values(i)*(1-p)^(N-values(i));
end
mn = dot(values,probs);
% NOTE: the variance is stored in 'vr', not 'var', so that the built-in
% var() used in Method 1 is not shadowed when this cell is re-run.
vr = dot((values-mn).^2,probs);
sd = sqrt(vr);
fprintf(1,'mean=%.2f, SD = %.2f var = %.2f\n',mn,sd,vr);

% Method 3: Compute the values for a single coin flip, then generalize
% since this is the sum of N independent coin flips (means and variances
% of independent variables add)
bernmean = p;
bernvar = p*(1-p);
mn = N*bernmean;
vr = N*bernvar;
sd = sqrt(vr);
fprintf(1,'mean=%.2f, SD = %.2f var = %.2f\n',mn,sd,vr);


%% 3 - Bayesian updating: Beta and Bernoulli distribution
% We start with a prior expectation of the probability of our Bernoulli
% random variable (P(heads) or P("e" in a name)).
% We need a probability distribution over values of P(head), i.e., the 
% distribution ranges over the interval [0,1]. The "beta" distribution
% is the standard distribution used for this. It has two parameters, which
% behave as if you had already collected a series of observations of heads
% and tails. beta(1,1) is a flat, uniform distribution over [0,1], 
% beta(a+1,b+1) acts as if you have already tossed the coin a+b times and 
% observed 'a' heads and 'b' tails
x = linspace(0,1,101);          % calculate the distribution over x = 0,.01,.02,...,1
a = random(Bino);               % pick a random initial number of heads
b = N-a;                        % the corresponding number of tails
beta1 = betapdf(x,a+1,b+1);     % compute the beta "prior" distribution
% Now, plot the prior distribution and compute its maximum (i.e., the 
% most likely estimate of the probability of heads before collecting the
% new sample data)
figure;
plot(x,beta1)
hold on
% find the mode (on the discrete grid x)
[~,idx] = max(beta1);
sprintf('most likely rate is %.3f',x(idx))

% Next, collect new sample data
a2 = random(Bino);              % number of heads in the new sample
b2 = N-a2;                      % tails in the NEW sample (must use a2, not a)

% Next, compute the posterior distribution, combining the new sample
% with the assumed sample that gave rise to the prior distribution

% Method 1: Just update the parameters of the beta distribution to 
% add in the new sample values
% Compute that posterior, plot it in the same figure as the prior and
% compute the x-value that yields its maximum (the "Maximum a Posteriori"
% or MAP estimate)
atotal = a2 + a;
btotal = b2 + b;
beta2 = betapdf(x,atotal+1,btotal+1);
plot(x,beta2,'LineWidth',2)
hold on
% find the mode
[~,idx] = max(beta2);
sprintf('most likely rate now becomes %.3f',x(idx))

% Method 2 - use Bayes Rule
% compute the unnormalized posterior P(x|data) ~ P(data|x)P(x), rescale it
% to have the same peak as the Method 1 posterior (so the two curves can be
% compared visually), and compute the value of x for which posterior
% probability is maximal. The binomial coefficient is omitted from the
% likelihood because it does not depend on x.
prior = beta1;
likelihood = x.^a2 .* (1-x).^b2;
posterior = prior.*likelihood;
[~,idx] = max(posterior);
sprintf('manually get most likely rate: %.3f',x(idx))

magnifier = max(beta2)/max(posterior);
plot(x,posterior*magnifier,'--','LineWidth',2)

%% 4 - Bayesian updating: Normal distribution
%
% The previous example showed that the "conjugate" distribution to the
% Binomial is the Beta. That is, if you have a Beta prior distribution
% over the coin's probability 'p' and combine with binomial data (a set
% of coin tosses governed by P(heads)=p), the posterior distribution 
% remains a beta
%
% Here, we'll qualitatively check that the conjugate distribution for a
% normal distribution (unknown mean, but known variance sigma^2) is the
% normal distribution

% Pick a mean and variance for your prior distribution
priormean = 0;
priorvar = 1;

% Pick a set of parameters from which you will draw your samples (number
% of samples and mean). You might as well use the common variance, since
% that's what we are assuming
N = 20;
datamean = 3;

% Draw a sample (the data share the prior's standard deviation)
priorsd = sqrt(priorvar);
sample = datamean + priorsd*randn(N,1);

% Compute the prior over a fine, discrete set of values that is wide
% enough to encompass all samples and most of the probability mass
x = -5:.01:7;
nsample = length(x);            % number of grid points in x
prior = (1/(sqrt(2*pi)*priorsd))*exp(-(x-priormean).^2/(2*priorvar));

% Compute the posterior using Bayes Rule, normalizing both prior and 
% posterior to sum to one over the discrete points
%
% deltaxs(i,j) = sample(i) - x(j): each data point against each candidate mean
deltaxs = bsxfun(@minus, sample, x);
% Work with the log-likelihood: a product of N densities underflows to 0
% for large N, which would turn the normalized posterior into NaN. Summing
% along dimension 1 explicitly also keeps the N = 1 case a row vector.
loglikelihood = sum(-deltaxs.^2/(2*priorvar) - log(sqrt(2*pi)*priorsd), 1);
% Likelihood relative to its maximum; the constant factor cancels when the
% posterior is normalized below.
likelihood = exp(loglikelihood - max(loglikelihood));
posterior = likelihood.*prior;

prior = prior/sum(prior);
posterior = posterior/sum(posterior);

% Plot prior and posterior

figure;
plot(x,prior,'-');
hold on;
plot(x,posterior,'--');

% Play with this, examining what happens as the sample deviates more and
% more from the mean of the prior, and as the sample size increases
