# Probabilistic algorithms on bit vectors

Require Import Arith.

Require Export BVlib.

Require Export DistrTactic.

Require Export IsDiscrete.

(* [si] applies [simpl_mu_rewrite idtac] repeatedly until it fails.
   [simpl_mu_rewrite] comes from the imported libraries; presumably it
   simplifies [mu] applied to monadic terms -- confirm in DistrTactic. *)
Ltac si := repeat simpl_mu_rewrite idtac.

(* From here on, arguments that can be inferred from the types of other
   arguments are declared implicit automatically. *)
Set Implicit Arguments.

# Bit vectors form a discrete type with decidable equality

(** [bvpoints n code] decodes the natural number [code] into a bit vector
    of length [n].  [bij_n_nxn] splits [code] into a pair of codes: the
    first selects the head bit through [points], and the second is decoded
    recursively into the tail.  [bvector_discrete] uses this enumeration. *)
Fixpoint bvpoints n code :=
match n return Bvector n with
| S len =>
    match bij_n_nxn code with
    | (c_hd, c_tl) => Vcons bool (points c_hd) len (bvpoints len c_tl)
    end
| 0 => Vnil
end.

(** Bit vectors of length [n] form a discrete domain.  They are enumerated
    by [bvpoints n], and the proof shows that this enumeration reaches every
    vector. *)
Instance bvector_discrete : forall n, Discrete_domain (Bvector n).

Proof.

intro.

apply Build_Discrete_domain with (points := bvpoints n).

(* Surjectivity of [bvpoints n], by induction on the length [n]. *)
induction n; intros.

(* Length 0: rewrite [x] to the empty vector, which code 0 reaches. *)
simpl.
rewrite (V0_eq _ x). apply ex_intro with (x:=0%nat); reflexivity.
(* Length S n: split [x] into head [x_hd] and tail [x_tl]. *)
VS_rewrite x.

simpl.

(* Find a code for the head bit and a code for the tail, then combine them
   into a single code with [bij_surj]. *)
destruct (points_surj x_hd) as (hd,Hhd).

destruct (IHn x_tl) as (tl,Htl).

destruct (bij_surj hd tl) as (code,Hn).

exists code.

rewrite Hn; rewrite Hhd; rewrite Htl; reflexivity.

Qed.

(** Equality on [Bvector n] is decidable, witnessed by [Bvector_eq_dec]. *)
Instance bvector_decidable (n : nat) : DecidEq (Bvector n) :=
  Build_DecidEq (@Bvector_eq_dec n).

# Distributions on bit vectors

## Uniformly choose a vector of length n

(** [BVrandom n] draws the [n] bits one after the other with [Flip], head
    first, and returns the resulting vector.  [BVrandom_eq] states that each
    vector of length [n] gets probability [1/2]^n. *)
Fixpoint BVrandom n :=
match n return distr (Bvector n) with
| 0 => Munit Bnil
| S len =>
    Mlet Flip (fun bit =>
      Mlet (BVrandom len) (fun rest =>
        Munit (Vcons bit rest)))
end.

(** Every vector of length [n] has probability [1/2]^n under [BVrandom n]. *)
Theorem BVrandom_eq : forall n (x : Bvector n),
mu (BVrandom n) (fun b => B2U (BVeq b x)) == [1/2]^n.

Proof.

(* Induction on the vector [x]; the empty case is immediate. *)
induction x.

reflexivity.

(* Cons case: unfold one step of [BVrandom] and simplify the monadic terms. *)
simpl BVrandom; si.

(* Case on the head bit [a] and use the induction hypothesis on the tail. *)
case a; rewrite IHx; si; auto.

Qed.

(** XOR with a fixed key [k] preserves [BVrandom n].  Integrating [f] after
    XOR-ing the random vector with [k] gives the same value as integrating
    [f] directly. *)
Theorem BVxor_noise_si : forall n k f,
mu (BVrandom n) (fun m => f (BVxor n k m)) == mu (BVrandom n) f.

Proof.

intros n k f.

(* Induction on the key [k]. *)
induction k.

reflexivity.

(* Cons case: unfold one step of [BVrandom] and simplify. *)
simpl BVrandom; si.

(* Rewrite only occurrences 3 and 4 with the induction hypothesis. *)
setoid_rewrite <- IHk at 3 4.

(* Conclude by case on the head bit [a] of the key. *)
case a; auto.

Qed.

(** Distribution-level form of [BVxor_noise_si]: XOR-ing a draw of
    [BVrandom n] with a fixed [k] yields [BVrandom n] again. *)
Theorem BVxor_noise : forall n k,
Mlet (BVrandom n) (fun m => Munit (BVxor n k m)) == BVrandom n.

Proof.

(* The distribution equality is presumably checked against an arbitrary
   function [f], which reduces to [BVxor_noise_si]. *)
intros n k f; si; apply BVxor_noise_si.

Qed.

## Uniformly choose a vector of length n with k bits set to 1

(** [BVrandom_k n k] builds a vector of length [n] one bit at a time.
    The head is [true] with weight [k */ [1/]1+len], where [len] is the
    length of the tail; in that case [k-1] ones remain to be placed in the
    tail.  Otherwise the head is [false] and all [k] ones go to the tail.
    This assumes [Mchoice p d1 d2] picks [d1] with probability [p] --
    confirm in the probability library. *)
Fixpoint BVrandom_k n k :=
match n return distr (Bvector n) with
| 0 => Munit Bnil
| S len =>
    Mchoice (k */ [1/]1+len)
      (Mlet (BVrandom_k len (k-1)%nat) (fun rest => Munit (Vcons true rest)))
      (Mlet (BVrandom_k len k) (fun rest => Munit (Vcons false rest)))
end.

(** For [k <= n], [BVrandom_k n k] only produces vectors with exactly [k]
    occurrences of [true]. *)
Theorem BVrandom_k_range :
forall n k,
(k <= n)%nat ->
range (fun x => @occurrences n x true = k) (BVrandom_k n k).

Proof.

unfold range.

(* Induction on [n], keeping [k] and the test function general. *)
intro n; induction n; intros.

(* n = 0: [k] must be 0 and the only result is the empty vector. *)
destruct k; [|elim (le_Sn_O _ H)].

simpl.

apply H0.

reflexivity.

(* Successor case: unfold one step of [BVrandom_k]. *)
simpl BVrandom_k; si.

(* HX: when [k > 0], the branch that prepends [true] puts no mass outside
   the predicate. *)
assert ((0 < k)%nat ->
0%U == mu (BVrandom_k n (k-1)) (fun x : Bvector n => f (Vcons true x)))
as HX.

intros.

apply IHn; try omega; intros.

apply H0; simpl; rewrite H2; omega.

(* HY: when [k < S n], the branch that prepends [false] puts no mass
   outside the predicate. *)
assert ((k < S n)%nat ->
0 == mu (BVrandom_k n k) (fun x : Bvector n => f (Vcons false x)))
as HY.

intros.

apply IHn; try omega; intros.

apply H0; simpl; rewrite H2; omega.

(* Either all bits are requested ([k = S n]) or [k < S n]. *)
assert (k = S n \/ k < S n)%nat as [H3|H3] by omega.

(* k = S n: use HX, and the choice weight [S n */ [1/]1+n] simplifies
   to 1. *)
rewrite <- HX by omega.

rewrite H3.

setoid_replace (S n */ [1/]1+n) with 1 by auto.

repeat Usimpl; auto.

(* k < S n: use HY, then split on whether [k] is 0. *)
rewrite <- (HY H3); repeat Usimpl.

assert (k > 0 \/ k = 0)%nat as [H4|H4] by omega.

rewrite <- (HX H4); auto.

(* k = 0: the [true] branch is weighted by [0 */ _]. *)
rewrite H4; auto.

Qed.

(** The predicate used in [BVrandom_k_range] is decidable. *)
Theorem BVrandom_k_range_dec : forall n k, dec (fun x => @occurrences n x true = k).

Proof.

(* Presumably closed by a decision procedure for equality on [nat] found
   by [auto with arith]. *)
intros; intro; auto with arith.

Qed.

When k <= n, BVrandom_k n k terminates with probability 1 and returns a vector of n bits with exactly k occurrences of true.

(** For [k <= n], [BVrandom_k n k] returns a vector with exactly [k]
    occurrences of [true] with probability 1. *)
Theorem BVrandom_k_correct : forall n k, k <= n ->
mu (BVrandom_k n k) (fun w => B2U (beq_nat (@occurrences n w true) k)) == 1.

Proof.

intro n.

(* Induction on [n], with [k] still universally quantified. *)
induction n; intros.

(* n = 0 forces k = 0. *)
assert (k=0)%nat as Hk.

auto with arith.

rewrite Hk.

auto.

(* Successor case: unfold one step of [BVrandom_k]. *)
simpl BVrandom_k; si.

unfold carac; simpl occurrences.

destruct k.

(* k = 0: only the [false] branch contributes; conclude with [IHn]. *)
repeat Usimpl.

apply IHn.

auto with arith.

(* k = S k': split [k' <= n] into [k' < n] and [k' = n]. *)
simpl pred.

elim (Compare.le_decide k n (le_S_n _ _ H)); intro.

(* k' < n: both branches have total mass 1 by [IHn]. *)
unfold carac in IHn.

rewrite <- pred_of_minus; simpl pred.

rewrite (IHn k).

rewrite (IHn (S k)).

auto.

auto with arith.

auto with arith.

(* k' = n: the choice weight saturates, leaving only the [true] branch. *)
rewrite b.

rewrite (Nmult_ge_Sn_Unth (le_n (S n))).

repeat Usimpl.

rewrite <- pred_of_minus; simpl pred.

apply IHn.

auto.

Qed.

(** The product of two positive natural numbers is positive. *)
Theorem O_lt_mult : forall n m, (0 < n -> 0 < m -> 0 < n*m)%nat.

Proof.

(* If either factor is zero, the corresponding hypothesis is absurd.  In
   the remaining case [S n' * S m'] computes to a successor. *)
intros [|n'] [|m'] Hn Hm; try (exfalso; omega).

apply lt_0_Sn.

Qed.

(** Point probabilities of [BVrandom_k n k] for any [k], including
    [k > n].  A vector [w] whose number of [true] bits equals [min n k] gets
    [Unth (pred (comb k n))]; any other vector gets 0.  When [comb k n] is
    positive, [Unth (pred (comb k n))] is [1 / comb k n]. *)
Theorem BVrandom_k_eq_min :
forall n w k,
mu (BVrandom_k n k) (fun w' => B2U (@BVeq n n w' w)) ==
(if beq_nat (occurrences w true) (Min.min n k)%nat then
Unth (pred (comb k n))
else 0).

Proof.

(* Induction on [w], keeping [k] general. *)
induction w; intro k.

(* Empty vector. *)
si; simpl.

destruct k; auto.

(* Cons case: unfold one step of [BVrandom_k]. *)
simpl BVrandom_k; si.

(* Comparing two cons vectors is comparing the heads times comparing the
   tails. *)
assert (forall b (w':Bvector n), B2U (BVeq (Vcons b w') (Vcons a w)) ==
B2U (eqb b a) * B2U (BVeq w' w)) by
(intros [|] w'; destruct a; simpl; repeat Usimpl; reflexivity).

(* Pull the head factors out of both branches, then apply [IHw] to each. *)
setoid_rewrite H; clear H; rewrite <- 2!fmult_def; rewrite 2!mu_stable_mult.

rewrite 2! IHw.

(* Case on the head bit [a] of [w], starting with [true]. *)
destruct a; simpl eqb; simpl B2U; repeat Usimpl.

(* Head [true]: either more ones are requested than positions, or not. *)
assert (k > S n \/ k <= S n)%nat as [Hlt|Hle] by omega.

rewrite (Min.min_l n (k-1)%nat) by omega.

rewrite (Min.min_l (S n) k) by omega.

simpl occurrences; simpl beq_nat.

case (beq_nat (occurrences w true) n); auto.

rewrite comb_not_le by omega.

rewrite comb_not_le by omega.

simpl; repeat Usimpl.

apply Nmult_ge_Sn_Unth; omega.

(* Head [true] with [k <= S n]. *)
rewrite (Min.min_r n (k-1)%nat) by omega.

rewrite (Min.min_r (S n) k) by omega.

simpl occurrences.

destruct k; [auto|].

simpl beq_nat.

replace (k-0)%nat with k by omega.

case (beq_nat (occurrences w true) k); auto.

(* Merge the branch weight with the tail probability, then check the
   identity on binomial coefficients. *)
rewrite <- Nmult_Umult_assoc_left by auto.

replace ([1/]1+n) with ([1/](S n)) by auto.

rewrite Umult_Nnth; [| omega | apply comb_le_0_lt; omega ].

change ([1/] comb (S k) (S n)) with (1 */ [1/] comb (S k) (S n)).

apply Nmult_Unth_eq.

rewrite mult_1_l.

rewrite <- S_pred with (m := O) by (apply comb_le_0_lt; omega).

replace (S k - 1)%nat with k by omega.

rewrite <- S_pred with (m := O) by (apply O_lt_mult; [|apply comb_le_0_lt];omega).

aac_rewrite comb_incr_k; aac_rewrite comb_incr_n; aac_reflexivity.

(* Presumably the head [false] case. *)
simpl occurrences.

(* Either [k] exceeds the tail length [n] or not. *)
assert (n < k \/ k <= n)%nat as [Hlt|Hle] by omega.

rewrite (Min.min_l n k) by omega.

rewrite (Min.min_l (S n) k) by omega.

(* A tail of length [n] cannot contain [S n] occurrences of [true]. *)
case_eq (beq_nat (occurrences w true) (S n)); intro.

assert (occurrences w true <= n)%nat by apply occurrences_length_le.

rewrite (beq_nat_eq _ _ (eq_sym H)) in H0.

assert False as F by omega; elim F.

setoid_replace (k */ [1/]1+n) with 1; auto.

(* [k <= n]. *)
rewrite (Min.min_r (S n) k) by omega.

rewrite (Min.min_r n k) by omega.

case (beq_nat (occurrences w true) k); [|auto].

rewrite Uinv_Nmult.

rewrite <- Nmult_Umult_assoc_left by (apply Nmult_def_Unth_le; omega).

replace ([1/]1+n) with ([1/](S n)) by auto.

rewrite Umult_Nnth; [| omega | apply comb_le_0_lt; omega ].

change ([1/] comb k (S n)) with (1 */ [1/] comb k (S n)).

apply Nmult_Unth_eq.

rewrite mult_1_l.

rewrite <- S_pred with (m := O) by (apply comb_le_0_lt; omega).

rewrite <- S_pred with (m := O) by (apply O_lt_mult; [|apply comb_le_0_lt]; omega).

aac_rewrite comb_incr_n.

aac_reflexivity.

Qed.

(** [BVrandom_k_eq_min] specialised to [k <= n], where [Min.min n k] is
    [k]. *)
Theorem BVrandom_k_eq :
forall n k, (k <= n)%nat -> forall w,
mu (BVrandom_k n k) (fun w' => B2U (@BVeq n n w' w)) ==
(if beq_nat (occurrences w true) k then
Unth (pred (comb k n))
else 0).

Proof.

intros.

(* Replace only the second occurrence of [k] (the one compared with
   [occurrences w true]) by [Min.min n k]. *)
rewrite <- (Min.min_r n k H) at 2.

apply BVrandom_k_eq_min.

Qed.

## Uniformly choose a vector whose k bits set to 1 are all taken among the bits set in a mask

(** [BVrandom_k_mask k mask] draws a vector whose 1-bits all lie at
    positions where [mask] is 1.
    - A position where [mask] is 0 always yields [false].
    - A position where [mask] is 1 is set with probability
      [Nmult k (Unth (occs_1 tl))], i.e. [k] divided by the number of mask
      bits that remain, this one included.  This assumes [occs_1 tl] counts
      the 1-bits of the mask tail [tl].
    - When the bit is set, one fewer 1-bit ([pred k]) remains to be placed
      in the tail. *)
Fixpoint BVrandom_k_mask n (k:nat) (mask : Bvector n) : distr (Bvector n) :=
(* The return type [distr (Bvector n)] is inferred by abstracting the
   length index of [mask]. *)
match mask with
| Vnil => Munit Vnil
| Vcons false _ tl =>
  Mlet (BVrandom_k_mask k tl) (fun tl => Munit (Vcons false tl))
| Vcons true _ tl =>
  Mlet (bernoulli (Nmult k (Unth (occs_1 tl)))) (fun r =>
  if r then
  Mlet (BVrandom_k_mask (pred k) tl) (fun tl => Munit (Vcons true tl))
  else
  Mlet (BVrandom_k_mask k tl) (fun tl => Munit (Vcons false tl)))
end.

(* NOTE(review): the statement this proof belongs to is missing here:
   [Proof.] directly follows the [BVrandom_k_mask] definition.  Later code
   applies [BVrandom_k_mask_range k mask] as a [range] proof for
   [fun x => missing x mask = 0%nat], so this is presumably that lemma --
   restore its statement. *)
Proof.

(* Induction on the mask, with [k] and [f] generalised first. *)
generalize k f; induction mask; intros.

(* Empty mask. *)
simpl; apply H.

reflexivity.

(* Non-empty mask: case on its head bit. *)
simpl; destruct a; si.

(* NOTE(review): the [pred k0] branch of the definition prefixes
   [Vcons true], yet the induction hypothesis is instantiated with
   [Vcons false] here -- check this against the actual goal. *)
rewrite <- (IHmask (fun x : Bvector n => f (Vcons false x)) (pred k0)).

rewrite <- (IHmask (fun x : Bvector n => f (Vcons false x)) k0).

repeat Usimpl; reflexivity.

(* Side conditions: prefixing a bit keeps [missing] at 0. *)
intros; apply H.

unfold missing, occs_1; simpl; exact H0.

intros; apply H.

unfold missing, occs_1; simpl; exact H0.

apply (IHmask (fun x : Bvector n => f (Vcons false x)) k0).

intros; apply H.

unfold missing, occs_1; simpl; exact H0.

Qed.

(* [omega_absurd] closes the current goal by proving [False] with [omega],
   i.e. when the arithmetic hypotheses are contradictory. *)
Ltac omega_absurd := exfalso; omega.

(* [S_rewrite x] destructs the natural number [x].  The [0] case must be
   discharged by [omega]; the successor case is left as the goal. *)
Ltac S_rewrite x := destruct x; [ omega | idtac ].

(* NOTE(review): two parts of this theorem are missing from this chunk and
   should be restored from the original source:
   - the [Theorem <name> :] line that should precede this statement;
   - the opening tactics after [Proof.] that introduce [H] and presumably
     start an induction on [mask].  The [Vnil] case is handled immediately.
   Statement: assume [k] is at most [occs_1 mask] and [missing w mask = O].
   Then [BVrandom_k_mask k mask] returns [w] with probability
   [Unth (pred (comb k (occs_1 mask)))] when [occs_1 w = k], and 0
   otherwise. *)
forall n (mask : Bvector n),
forall k, (k <= occs_1 mask)%nat ->
forall w, missing w mask = O ->
mu (BVrandom_k_mask k mask) (fun w' => B2U (@BVeq n n w' w)) ==
(if beq_nat (occs_1 w) k then Unth (pred (comb k (occs_1 mask))) else 0).

Proof.

(* Empty mask: the bound forces [k = 0] and [w] is the empty vector. *)
change (occs_1 Vnil) with O in H.

replace k with O by omega.

rewrite (V0_eq _ w).

simpl; repeat Usimpl; reflexivity.

(* Non-empty mask: case on its head bit [a]. *)
destruct a.

(* Handle the second goal (mask head [false]) first.  A [true] head in [w]
   contradicts [H0]. *)
Focus 2.

VS_rewrite w.

destruct w_hd; try discriminate H0.

simpl.

change (occs_1 (Vcons false w_tl)) with (occs_1 w_tl).

case (occs_1 w_tl); try (repeat Usimpl; reflexivity).

(* Mask head [true]. *)
VS_rewrite w.

rewrite missing_true in H0.

(* Either [k] is one more than the 1-bits of the mask tail, or at most
   that many. *)
assert (k = S (occs_1 mask) \/ k <= occs_1 mask)%nat as H_ by omega;
destruct H_ as [Heq|Hle].

(* [k = S (occs_1 mask)]: [Nmult_ge_Sn_Unth] presumably saturates the
   head weight to 1. *)
rewrite Heq; rewrite Nmult_ge_Sn_Unth by omega; repeat Usimpl.

rewrite <- pred_Sn.

destruct w_hd; simpl BVeq.

repeat rewrite comb_n_n.

reflexivity.

(* Head of [w] is [false]: [w_tl] has at most [occs_1 mask] 1-bits, so it
   cannot have [S (occs_1 mask)] of them. *)
si.

assert (occs_1 w_tl <= occs_1 mask)%nat by (apply missing_occs_le; assumption).

change (occs_1 (Vcons false w_tl)) with (occs_1 w_tl).

rewrite if_beq_nat_nat_eq_dec.

case (eq_nat_dec (occs_1 w_tl) (S (occs_1 mask))); intros; [omega_absurd|reflexivity].

(* [k <= occs_1 mask]: case on the head bit of [w], starting with [true]. *)
destruct w_hd; simpl BVeq.

rewrite Umult_zero_right.
rewrite Uplus_zero_right.

(* Split on whether [k] is 0. *)
assert (k = 0 \/ k > 0)%nat as Hk by omega; destruct Hk as [Hk|Hk].

rewrite Hk; simpl; repeat Usimpl; reflexivity.

(* Restate the tail test as a test on the whole of [w]. *)
replace (beq_nat (occs_1 w_tl) (pred k)) with (beq_nat (occs_1 (Vcons true w_tl)) k).

case (beq_nat (occs_1 (Vcons true w_tl)) k); try (repeat Usimpl; reflexivity).

(* Merge the bernoulli weight with the tail probability. *)
rewrite <- Nmult_Umult_assoc_left by (apply Nmult_def_Unth_le; omega).

rewrite Umult_Nnth;
[ | omega | apply comb_le_0_lt; omega ].

rewrite Nmult_Unth_factor with (m2 := pred (comb k (occs_1 (Vcons true mask)))).

reflexivity.

(* Arithmetic side goals on [comb]. *)
rewrite <- S_pred with (m:=O) by
(apply comb_le_0_lt; unfold occs_1 in *|-*; simpl; omega).

S_rewrite k.

change (pred (S k)) with k.

aac_rewrite comb_incr_k.

rewrite <- S_pred with (m:=O).

aac_rewrite comb_incr_n.

aac_reflexivity.

assert ((comb k (occs_1 mask) > 0)%nat) by (apply comb_le_0_lt; omega).

simpl; auto with arith.

(* Presumably the side goal of the [replace] above. *)
destruct k; [ inversion Hk | reflexivity ].

(* Presumably the case where the head of [w] is [false]. *)
rewrite mu_fzero_eq; repeat Usimpl.
change (occs_1 (Vcons false w_tl)) with (occs_1 w_tl).
case (beq_nat (occs_1 w_tl)); try (repeat Usimpl; reflexivity).

rewrite Uinv_Nmult.

rewrite <- Nmult_Umult_assoc_left by (apply Nmult_def_Unth_le; omega).

rewrite Umult_Nnth; [ | omega | apply comb_le_0_lt; omega ].

apply Nmult_Unth_factor with (m2 := pred (comb k (occs_1 (Vcons true mask)))).

rewrite <- pred_Sn.

assert ((comb k (occs_1 mask) > 0)%nat) by (apply comb_le_0_lt; omega).

repeat rewrite <- S_pred with (m:=O).

aac_rewrite comb_incr_n; aac_reflexivity.

simpl; S_rewrite (comb k (occs_1 mask)); simpl; auto with arith.

apply comb_le_0_lt.

Qed.

# Correction of a vector by adding and removing 1-bits

correct K' ki km returns a corrected version of K': it clears ki of the 1-bits of K' and sets km of its 0-bits, each set of positions being chosen uniformly at random.
Definition correct : forall n, Bvector n -> nat -> nat -> distr (Bvector n) :=
fun n K' ki km =>
  (* [cleared]: [ki] positions drawn among the 1-bits of [K'], turned off. *)
  Mlet (BVrandom_k_mask ki K') (fun cleared =>
  (* [added]: [km] positions drawn among the 0-bits of [K'], turned on. *)
  Mlet (BVrandom_k_mask km (~~ K')) (fun added =>
  Munit ((K' && (~~ cleared)) || added))).

(** [fun x => B2U (beq_nat (missing x mask) 0%nat)] is the characteristic
    function ([cover]) of the predicate [fun x => missing x mask = 0%nat]. *)
Theorem missing_cover : forall n (mask : Bvector n),
cover (fun x => missing x mask = 0%nat) (fun x => B2U (beq_nat (missing x mask) 0%nat)).

Proof.

split; intro.

(* The predicate holds: the test becomes [beq_nat 0 0]. *)
rewrite H; reflexivity.

(* The predicate fails: [beq_nat_neq] presumably turns the test into
   [false]. *)
rewrite <- beq_nat_neq by assumption.

reflexivity.

Qed.

(** Suppose [a] has no 1-bit outside [~~ K'], i.e. it shares no 1-bit with
    [K'].  Then testing [K' || K] against [K' || a] is the same as testing
    [a] against [K && ~~ K']. *)
Lemma or_eq_missing : forall n (K K' a : Bvector n),
missing a (~~ K') = 0%nat ->
B2U (BVeq (K' || K) (K' || a)) ==
B2U (BVeq a (K && (~~ K'))).

Proof.

(* Induction on the length.  In the successor case all three vectors are
   split into head and tail. *)
induction n; intros;
[ repeat setoid_rewrite V0_eq | VS_rewrite a; VS_rewrite K; VS_rewrite K' ].

(* Length 0. *)
rewrite (V0_eq _ a); reflexivity.

simpl.

simpl Bneg in H.

(* Induction hypothesis on the tails. *)
assert (B2U (BVeq (K'_tl || K_tl) (K'_tl || a_tl)) ==
B2U (BVeq a_tl (K_tl && (~~ K'_tl)))).

apply IHn.

apply (missing_0_tl _ _ H).

(* Enumerate the head bits.  Inconsistent combinations make [H] an
   equation [S _ = 0]. *)
destruct K_hd; destruct K'_hd; destruct a_hd; simpl; try rewrite H0;
first [ reflexivity | rewrite missing_S in H; discriminate ].

Qed.

(** Adding [missing K K'] bits, drawn with [BVrandom_k_mask] among the
    0-bits of [K'], to [K'] gives [K' || K] with probability
    [1/] (comb (missing K K') (n - occs_1 K')). *)
Theorem cancel_missing : forall n (K K' : Bvector n),
mu
(Mlet (BVrandom_k_mask (missing K K') (~~ K')) (fun x => Munit (K' || x)))
(fun v => B2U (BVeq (K' || K) v))
== [1/] (comb (missing K K') (n - occs_1 K')).

Proof.

intros.

(* Presumably rewrites [n - occs_1 K'] as the number of 1-bits of [~~ K']. *)
rewrite <- Bneg_occs_1.

rewrite Mlet_simpl.

setoid_rewrite Munit_simpl.

(* On the support of [BVrandom_k_mask], testing against [K' || K] is the
   same as testing against [K && ~~ K'] ([or_eq_missing]). *)
rewrite (range_eq (BVrandom_k_mask_range (missing K K') (~~ K')))
with (g := (fun x : Bvector n => B2U (BVeq x (K && (~~ K')))))
by apply or_eq_missing.

rewrite <- beq_nat_refl; reflexivity.

(* Remaining side goals, closed with [occs_1_BVand] and [missing_BVand]. *)
unfold missing, occs_1.

rewrite (_ : AAC.Commutative eq (BVand (n:=n))); apply occs_1_BVand.

apply missing_BVand.

Qed.

(** Applying any [f : A -> U] to equal arguments gives equal values in [U]. *)
Proposition Of_equal : forall A (f : A -> U) x y, x = y -> f x == f y.

Proof.

intros A f x y Hxy; subst y; reflexivity.

Qed.

(** The product of two boolean indicators is the indicator of their
    conjunction. *)
Lemma Umult_B2U_andb : forall x y, (B2U x) * (B2U y) == B2U (andb x y).

Proof.

intros; unfold B2U; simpl.

(* Check the four combinations of [x] and [y]. *)
destruct x; destruct y; simpl; repeat Usimpl; reflexivity.

Qed.

(** Pointwise form of the success event of [correct].  The indicator that
    all three of the following hold:
    - [x] has no 1-bit outside [K'];
    - [y] has no 1-bit outside [~~ K'];
    - [(K' && ~~ x) || y] equals [K];
    equals the indicator that [x = K' && ~~ K] and [y = K && ~~ K']. *)
Lemma correct_success_eq : forall n (K K' : Bvector n),
forall x y,
B2U (beq_nat (missing x K') 0) *
(B2U (beq_nat (missing y (~~ K')) 0) *
B2U (BVeq K ((K' && ~~ x) || y)))
== B2U (BVeq x (K' && ~~ K)) *
B2U (BVeq y (K && ~~ K')).

Proof.

intros.

(* Merge the products into single conjunctions and compare the booleans. *)
repeat rewrite Umult_B2U_andb.

apply Of_equal.

(* Prove both directions of the equivalence. *)
rewrite bool_prop_iff; split; intro.

(* Left to right: turn the [missing] tests into disjointness equations. *)
bool_connectors; destruct H as (H1,(H2,H3)).

symmetry in H2; symmetry in H1.

assert (missing x K' = 0%nat) by (apply beq_nat_eq; assumption).

assert (missing y (~~K') = 0%nat) by (apply beq_nat_eq; assumption).

assert (x && ~~K' = BVfalse n) by (apply occurrences_true_0; assumption).

assert (y && (~~(~~K')) = BVfalse n) by (apply occurrences_true_0; assumption).

clear H1 H2 H H0.

rewrite involutive in H5.

rewrite BVeq_eq_true in H3.

(* Solve for [x] first, then for [y], by boolean-algebra rewriting. *)
split; rewrite BVeq_eq_true.

symmetry.

rewrite H3.

rewrite demorgan.

aac_rewrite (BVand_empty K' y); [| aac_rewrite H5; reflexivity ].

rewrite demorgan; BVsimpl.

rewrite BVand_or_distrib_r.

rewrite BVand_neg_false.

replace (K' && x) with (x && ~~ ~~ K') by (rewrite involutive; aac_reflexivity).

rewrite (BVand_empty x (~~K')); [ aac_reflexivity | assumption ].

(* Now [y]. *)
symmetry.

rewrite H3.

aac_rewrite (BVand_or_distrib_r (~~K') (K'&&~~x) y).

aac_rewrite (BVand_neg_false K').

rewrite BVand_false.

aac_rewrite (BVand_empty y K'); [ reflexivity | assumption ].

(* Right to left.  Local helper: an all-false vector has no [true]
   occurrence. *)
assert
(forall n (v : Bvector n), v = BVfalse _ -> occurrences v true = 0)%nat
as BVfalse_occs_0.

intros.

setoid_rewrite occurrences_Vconst.

assumption.

(* Substitute [x = K' && ~~ K] and [y = K && ~~ K'], then check the three
   conjuncts. *)
rewrite andb_true_iff in H; destruct H.

rewrite BVeq_eq_true in H; rewrite BVeq_eq_true in H0.

repeat rewrite andb_true_iff; split; [|split];
try (rewrite beq_nat_true_iff; apply BVfalse_occs_0);
try rewrite BVeq_eq_true.

rewrite H.

aac_rewrite (BVand_neg_false K').

apply BVand_false.

rewrite involutive.

rewrite H0.

aac_rewrite (BVand_neg_false K').

apply BVand_false.

rewrite H.

BVsimpl.

rewrite BVand_or_distrib_r.

rewrite (BVand_neg_false K').

rewrite H0.

aac_rewrite <- (BVand_or_distrib_r K K' (~~K')) in_right.

rewrite BVor_neg_true; aac_reflexivity.

Qed.

correct K' (missing K' K) (missing K K') recovers K from K' with probability [1/](comb (missing K' K) (occs_1 K') * comb (missing K K') (n - occs_1 K')), where n - occs_1 K' is the number of 0-bits of K'.
Theorem correct_eq :
forall n (K K' : Bvector n),
mu (correct K' (missing K' K) (missing K K')) (fun x => B2U (BVeq K x)) ==
[1/] (comb (missing K' K) (occs_1 K') * comb (missing K K') (n - occs_1 K')).

Proof.

intros; unfold correct; si.

rewrite <- Umult_Nnth;
[ | apply comb_le_0_lt; apply missing_le_occs
| apply comb_le_0_lt; apply missing_le_not_occs ].

rewrite (range_cover (BVrandom_k_mask_range (missing K' K) K') (missing_cover K')).

setoid_rewrite
(range_cover (BVrandom_k_mask_range (missing K K') (Bneg n K'))
(missing_cover (Bneg n K'))).

setoid_rewrite <- mu_stable_mult; unfold fmult.

setoid_rewrite correct_success_eq.

setoid_rewrite <- fmult_def; setoid_rewrite mu_stable_mult.

setoid_rewrite Umult_sym.

setoid_rewrite <- fmult_def; setoid_rewrite mu_stable_mult.

apply Umult_eq_compat.

| rewrite Bneg_occs_1; apply missing_le_not_occs
| apply missing_BVand ].

rewrite <- beq_nat_refl.

rewrite Bneg_occs_1.

reflexivity.