(** * Library RandomBV *)

(** Probabilistic algorithms on bit vectors. *)


Add Rec LoadPath "../../ALEA/src" as ALEA.


Require Import Arith.

Require Export BVlib.

Require Export DistrTactic.

Require Export IsDiscrete.


(* [si] repeatedly applies [simpl_mu_rewrite idtac] (provided by the imported
   DistrTactic library) until it fails.  It is used throughout this file right
   after unfolding a distribution, presumably to push [mu] through
   [Munit]/[Mlet]/[Mchoice] -- NOTE(review): confirm against DistrTactic. *)
Ltac si := repeat simpl_mu_rewrite idtac.


Set Implicit Arguments.


(** ** The space of bit vectors is discrete and has decidable equality *)


(* [bvpoints n code] decodes the natural number [code] into a vector of
   length [n]: [bij_n_nxn] splits [code] into a pair [(hd, rest)], the head
   bit is [points hd] (the [points] enumeration of [bool]) and the tail is
   decoded recursively from [rest].  This is the enumeration used by the
   [Discrete_domain] instance on bit vectors. *)
Fixpoint bvpoints n code :=
  match n return Bvector n with
    | O => Vnil
    | S m =>
        match bij_n_nxn code with
          | (hd, rest) => Vcons bool (points hd) m (bvpoints m rest)
        end
  end.


(* [Bvector n] is a discrete domain, enumerated by [bvpoints n].  The
   remaining obligation after [Build_Discrete_domain] asks for a code of every
   vector [x]; it is proved by induction on [n], combining a code for the head
   bit ([points_surj]) and a code for the tail (induction hypothesis) into a
   single code with [bij_surj]. *)
Instance bvector_discrete : forall n, Discrete_domain (Bvector n).

Proof.

  intro.

  apply Build_Discrete_domain with (points := bvpoints n).

  induction n; intros.

  (* n = 0: x is the empty vector, reached from code 0. *)
  simpl.
rewrite (V0_eq _ x). apply ex_intro with (x:=0%nat); reflexivity.
  (* n = S n: VS_rewrite splits x into its head x_hd and tail x_tl. *)
  VS_rewrite x.

  simpl.

  destruct (points_surj x_hd) as (hd,Hhd).

  destruct (IHn x_tl) as (tl,Htl).

  (* Pair the head code and the tail code into one code. *)
  destruct (bij_surj hd tl) as (code,Hn).

  exists code.

  rewrite Hn; rewrite Hhd; rewrite Htl; reflexivity.

Qed.


(* Equality on [Bvector n] is decidable, as witnessed by [Bvector_eq_dec]. *)
Instance bvector_decidable (n : nat) : DecidEq (Bvector n) :=
  Build_DecidEq (@Bvector_eq_dec n).


(** ** Distributions on bit vectors *)


(** Uniformly choose a vector of length [n]. *)


(* [BVrandom n]: uniform distribution on [Bvector n] (see [BVrandom_eq]).
   The head bit is drawn with [Flip], then the tail is drawn recursively and
   the two are assembled with [Vcons]. *)
Fixpoint BVrandom n :=
  match n return distr (Bvector n) with
    | 0 => Munit Bnil
    | S m => Mlet Flip (fun hd => Mlet (BVrandom m) (fun tl => Munit (Vcons hd tl)))
  end.


(* Every vector [x] of length [n] has probability [1/2]^n under
   [BVrandom n].  Proof by induction on the vector [x]. *)
Theorem BVrandom_eq : forall n (x : Bvector n),
  mu (BVrandom n) (fun b => B2U (BVeq b x)) == [1/2]^n.

Proof.

  induction x.

  (* Empty vector: closed by computation. *)
  reflexivity.

  (* Vcons a x: unfold one step of BVrandom, then split on the head bit a. *)
  simpl BVrandom; si.

  case a; rewrite IHx; si; auto.

Qed.


(* XOR-ing a uniformly random vector with a fixed vector [k] leaves its
   distribution unchanged: integrating [f] after [BVxor n k] is the same as
   integrating [f] directly against [BVrandom n].  Proof by induction on [k]. *)
Theorem BVxor_noise_si : forall n k f,
  mu (BVrandom n) (fun m => f (BVxor n k m)) == mu (BVrandom n) f.

Proof.

  intros n k f.

  induction k.

  (* k = Vnil *)
  reflexivity.

  (* k = Vcons a k: unfold one step of BVrandom, rewrite the tail integrals
     with IHk, then split on the head bit a. *)
  simpl BVrandom; si.

  setoid_rewrite <- IHk at 3 4.

  case a; auto.

Qed.


(* Distribution-level form of [BVxor_noise_si]: mapping [BVxor n k] over
   [BVrandom n] gives back [BVrandom n].  Equality [==] of distributions is
   checked against an arbitrary integrand [f]. *)
Theorem BVxor_noise : forall n k,
  Mlet (BVrandom n) (fun m => Munit (BVxor n k m)) == BVrandom n.

Proof.

  intros n k f; si; apply BVxor_noise_si.

Qed.


(** Uniformly choose a vector of length [n] with [k] bits set to 1. *)


(* [BVrandom_k n k] draws a vector of length [n] with [k] bits set to 1
   (see [BVrandom_k_correct] and [BVrandom_k_eq] for k <= n).
   For n = S m, [Mchoice] (presumably p times the first branch plus 1-p
   times the second) weighs the first branch by [k */ [1/]1+m], i.e. k out
   of the n positions: the head bit is [true] and [k-1] ones are placed in
   the tail.  Otherwise the head bit is [false] and all [k] ones go in the
   tail.
   NOTE(review): for k > n the weight presumably saturates at 1, giving the
   all-ones vector; [BVrandom_k_eq_min] states the general case with
   [min n k]. *)
Fixpoint BVrandom_k n k :=
  match n return distr (Bvector n) with
    | 0 => Munit Bnil
    | S m =>
        
        Mchoice (k */ [1/]1+m)
          (Mlet (BVrandom_k m (k-1)%nat) (fun tl => Munit (Vcons true tl)))
          (Mlet (BVrandom_k m k) (fun tl => Munit (Vcons false tl)))
  end.


(* For k <= n, every vector produced by [BVrandom_k n k] has exactly [k]
   occurrences of [true].  This is the [range] of that predicate.  Proof by
   induction on [n].  HX and HY are the induction hypothesis instantiated on
   the [Vcons true] branch (needs 0 < k) and on the [Vcons false] branch
   (needs k < S n).  The final case split is on k = S n versus k < S n. *)
Theorem BVrandom_k_range :
  forall n k,
    (k <= n)%nat ->
    range (fun x => @occurrences n x true = k) (BVrandom_k n k).

Proof.

  unfold range.

  intro n; induction n; intros.

  (* n = 0: k must be 0 and the only output is Bnil. *)
  destruct k; [|elim (le_Sn_O _ H)].

  simpl.

  apply H0.

  reflexivity.

  (* n = S n *)
  simpl BVrandom_k; si.

  assert ((0 < k)%nat ->
    0%U == mu (BVrandom_k n (k-1)) (fun x : Bvector n => f (Vcons true x)))
    as HX.

    intros.

    apply IHn; try omega; intros.

    apply H0; simpl; rewrite H2; omega.

  assert ((k < S n)%nat ->
    0 == mu (BVrandom_k n k) (fun x : Bvector n => f (Vcons false x)))
    as HY.

    intros.

    apply IHn; try omega; intros.

    apply H0; simpl; rewrite H2; omega.

  assert (k = S n \/ k < S n)%nat as [H3|H3] by omega.

  (* k = S n: the [true] branch has weight S n */ [1/]1+n, i.e. 1. *)
  rewrite <- HX by omega.

  rewrite H3.

  setoid_replace (S n */ [1/]1+n) with 1 by auto.

  repeat Usimpl; auto.

  (* k < S n: the [false] branch vanishes by HY; the [true] branch vanishes
     by HX when k > 0, and has weight 0 when k = 0. *)
  rewrite <- (HY H3); repeat Usimpl.

  assert (k > 0 \/ k = 0)%nat as [H4|H4] by omega.

    rewrite <- (HX H4); auto.

    rewrite H4; auto.

Qed.


(* Having exactly [k] occurrences of [true] is a decidable property of
   [Bvector n], because equality on [nat] is decidable. *)
Theorem BVrandom_k_range_dec : forall n k, dec (fun x => @occurrences n x true = k).
Proof.
  intros n k x.
  exact (eq_nat_dec (@occurrences n x true) k).
Qed.


(** If [k <= n], [BVrandom_k n k] terminates with probability 1 and returns a vector of [n] bits with exactly [k] occurrences of [true]. *)

(* For k <= n, [BVrandom_k n k] has total mass 1 on vectors with exactly [k]
   occurrences of [true].  So it terminates with probability 1 and always
   returns such a vector.  The bound is annotated [%nat] like the other
   statements of this file, so it cannot be parsed with an order from the
   open U/O scopes.  Proof by induction on [n], then case analysis on [k]. *)
Theorem BVrandom_k_correct : forall n k, (k <= n)%nat ->
  mu (BVrandom_k n k) (fun w => B2U (beq_nat (@occurrences n w true) k)) == 1.

Proof.

intro n.

induction n; intros.


 (* n = 0: necessarily k = 0. *)
 assert (k=0)%nat as Hk.

 auto with arith.

 rewrite Hk.

 auto.


 (* n = S n: unfold one step of BVrandom_k. *)
 simpl BVrandom_k; si.

 unfold carac; simpl occurrences.

 destruct k.


  (* k = 0: only the [false] branch carries weight. *)
  repeat Usimpl.

  apply IHn.

  auto with arith.


  (* k = S k: Compare.le_decide splits k <= n into k < n (a) or k = n (b). *)
  simpl pred.

  elim (Compare.le_decide k n (le_S_n _ _ H)); intro.

  unfold carac in IHn.

  rewrite <- pred_of_minus; simpl pred.

  rewrite (IHn k).

  rewrite (IHn (S k)).

  auto.

  auto with arith.

  auto with arith.

  (* k = n: the [true] branch has weight 1 (Nmult_ge_Sn_Unth). *)
  rewrite b.

  rewrite (Nmult_ge_Sn_Unth (le_n (S n))).

  repeat Usimpl.

  rewrite <- pred_of_minus; simpl pred.

  apply IHn.

  auto.


Qed.


(* The product of two positive natural numbers is positive.  The zero cases
   contradict the hypotheses; the successor case is a successor. *)
Theorem O_lt_mult : forall n m, (0 < n -> 0 < m -> 0 < n*m)%nat.
Proof.
  intros n m Hn Hm.
  destruct n as [|n']; [ inversion Hn | ].
  destruct m as [|m']; [ inversion Hm | ].
  apply lt_0_Sn.
Qed.


(* Exact point probabilities of [BVrandom_k n k] for an arbitrary [k].  A
   vector [w] has probability [Unth (pred (comb k n))] if it has [min n k]
   occurrences of [true], and 0 otherwise.  The k > n situations are handled
   with [comb_not_le].  Proof by induction on [w], then case analysis on the
   head bit of [w] and on how [k] compares with the length. *)
Theorem BVrandom_k_eq_min :
  forall n w k,
   mu (BVrandom_k n k) (fun w' => B2U (@BVeq n n w' w)) ==
   (if beq_nat (occurrences w true) (Min.min n k)%nat then
      Unth (pred (comb k n))
    else 0).

Proof.

  induction w; intro k.


  (* w = Vnil *)
  si; simpl.

  destruct k; auto.


  (* w = Vcons a w: unfold one step of BVrandom_k, and factor the equality
     test on Vcons into a head test times a tail test. *)
  simpl BVrandom_k; si.

  assert (forall b (w':Bvector n), B2U (BVeq (Vcons b w') (Vcons a w)) ==
                    B2U (eqb b a) * B2U (BVeq w' w)) by
    (intros [|] w'; destruct a; simpl; repeat Usimpl; reflexivity).

  setoid_rewrite H; clear H; rewrite <- 2!fmult_def; rewrite 2!mu_stable_mult.

  rewrite 2! IHw.

  (* Case on the head bit a of w; a = true comes first. *)
  destruct a; simpl eqb; simpl B2U; repeat Usimpl.

   (* a = true, k > S n *)
   assert (k > S n \/ k <= S n)%nat as [Hlt|Hle] by omega.

    rewrite (Min.min_l n (k-1)%nat) by omega.

    rewrite (Min.min_l (S n) k) by omega.

    simpl occurrences; simpl beq_nat.

    case (beq_nat (occurrences w true) n); auto.

    rewrite comb_not_le by omega.

    rewrite comb_not_le by omega.

    simpl; repeat Usimpl.

    apply Nmult_ge_Sn_Unth; omega.

    (* a = true, k <= S n *)
    rewrite (Min.min_r n (k-1)%nat) by omega.

    rewrite (Min.min_r (S n) k) by omega.

    simpl occurrences.

    destruct k; [auto|].

    simpl beq_nat.

    replace (k-0)%nat with k by omega.

    case (beq_nat (occurrences w true) k); auto.

    rewrite <- Nmult_Umult_assoc_left by auto.

    replace ([1/]1+n) with ([1/](S n)) by auto.

    rewrite Umult_Nnth; [| omega | apply comb_le_0_lt; omega ].

    change ([1/] comb (S k) (S n)) with (1 */ [1/] comb (S k) (S n)).

    apply Nmult_Unth_eq.

    rewrite mult_1_l.

    rewrite <- S_pred with (m := O) by (apply comb_le_0_lt; omega).

    replace (S k - 1)%nat with k by omega.

    rewrite <- S_pred with (m := O) by (apply O_lt_mult; [|apply comb_le_0_lt];omega).

    aac_rewrite comb_incr_k; aac_rewrite comb_incr_n; aac_reflexivity.


   (* a = false *)
   simpl occurrences.

   assert (n < k \/ k <= n)%nat as [Hlt|Hle] by omega.

    (* a = false, n < k *)
    rewrite (Min.min_l n k) by omega.

    rewrite (Min.min_l (S n) k) by omega.

    (* A tail of length n cannot hold S n ones (occurrences_length_le). *)
    case_eq (beq_nat (occurrences w true) (S n)); intro.

    assert (occurrences w true <= n)%nat by apply occurrences_length_le.

    rewrite (beq_nat_eq _ _ (eq_sym H)) in H0.

    assert False as F by omega; elim F.

    setoid_replace (k */ [1/]1+n) with 1; auto.

    (* a = false, k <= n *)
    rewrite (Min.min_r (S n) k) by omega.

    rewrite (Min.min_r n k) by omega.

    case (beq_nat (occurrences w true) k); [|auto].

    rewrite Uinv_Nmult.

    rewrite <- Nmult_Umult_assoc_left by (apply Nmult_def_Unth_le; omega).

    replace ([1/]1+n) with ([1/](S n)) by auto.

    rewrite Umult_Nnth; [| omega | apply comb_le_0_lt; omega ].

    change ([1/] comb k (S n)) with (1 */ [1/] comb k (S n)).

    apply Nmult_Unth_eq.

    rewrite mult_1_l.

    rewrite <- S_pred with (m := O) by (apply comb_le_0_lt; omega).

    rewrite <- S_pred with (m := O) by (apply O_lt_mult; [|apply comb_le_0_lt]; omega).

    aac_rewrite comb_incr_n.

    aac_reflexivity.

Qed.


(* Point probabilities of [BVrandom_k n k] when k <= n.  A vector [w] has
   probability [Unth (pred (comb k n))] if it has exactly [k] ones, and 0
   otherwise.  This is [BVrandom_k_eq_min] once [min n k] is rewritten to
   [k]. *)
Theorem BVrandom_k_eq :
  forall n k, (k <= n)%nat -> forall w,
   mu (BVrandom_k n k) (fun w' => B2U (@BVeq n n w' w)) ==
   (if beq_nat (occurrences w true) k then
      Unth (pred (comb k n))
    else 0).
Proof.
  intros n k Hkn w.
  (* Turn only the k of the [beq_nat] test back into [min n k]. *)
  rewrite <- (Min.min_r n k Hkn) at 2.
  apply BVrandom_k_eq_min.
Qed.


(** Uniformly choose a vector with [k] bits set to 1 among the positions set in a mask. *)


(* [BVrandom_k_mask k mask] draws a vector [x] with [missing x mask = O]
   (see [BVrandom_k_mask_range]; presumably: no 1-bit of [x] outside [mask]).
   For k <= occs_1 mask, [x] has exactly [k] ones chosen uniformly among the
   positions of [mask] (see [BVrandom_k_mask_eq]).
   - A [false] mask bit always yields a [false] output bit.
   - A [true] mask bit yields [true] with probability
     [Nmult k (Unth (occs_1 tl))], i.e. k out of the [S (occs_1 tl)] mask
     positions left; [pred k] ones then remain for the tail.  Otherwise it
     yields [false] and [k] ones remain. *)
Fixpoint BVrandom_k_mask n (k:nat) (mask : Bvector n) : distr (Bvector n) :=
  match mask with
    | Vnil => Munit Vnil
    | Vcons false _ tl =>
        Mlet (BVrandom_k_mask k tl) (fun tl => Munit (Vcons false tl))
    | Vcons true _ tl =>
        Mlet (bernoulli (Nmult k (Unth (occs_1 tl)))) (fun r =>
          if r then
            Mlet (BVrandom_k_mask (pred k) tl) (fun tl => Munit (Vcons true tl))
          else
            Mlet (BVrandom_k_mask k tl) (fun tl => Munit (Vcons false tl)))
  end.


(* Every vector produced by [BVrandom_k_mask k mask] satisfies
   [missing x mask = O], whatever [k].  Proof by induction on [mask],
   generalising over [k] and the integrand [f]. *)
Theorem BVrandom_k_mask_range : forall n k (mask : Bvector n),
  Prog.range (fun x => missing x mask = O) (BVrandom_k_mask k mask).

Proof.

  intros n k mask f.

  generalize k f; induction mask; intros.

  (* mask = Vnil *)
  simpl; apply H.

  reflexivity.

  (* mask = Vcons a mask: split on the head mask bit (true first). *)
  simpl; destruct a; si.

  (* NOTE(review): both rewrites below instantiate IHmask with [Vcons false],
     including the [pred k0] one, whereas BVrandom_k_mask builds [Vcons true]
     in its [pred k] branch.  They also pass the function before the nat.
     Confirm that both match the goal produced by [si] and the binder order
     produced by [generalize k f]. *)
  rewrite <- (IHmask (fun x : Bvector n => f (Vcons false x)) (pred k0)).

  rewrite <- (IHmask (fun x : Bvector n => f (Vcons false x)) k0).

  repeat Usimpl; reflexivity.

  intros; apply H.

  unfold missing, occs_1; simpl; exact H0.

  intros; apply H.

  unfold missing, occs_1; simpl; exact H0.

  (* Head mask bit false. *)
  apply (IHmask (fun x : Bvector n => f (Vcons false x)) k0).

  intros; apply H.

  unfold missing, occs_1; simpl; exact H0.

Qed.


(* Close the current goal by contradiction: switch the goal to [False] and
   let [omega] refute the hypotheses. *)
Ltac omega_absurd := exfalso; omega.

(* [S_rewrite x]: case-split the natural number [x].  The [x = 0] case must
   be discharged by [omega] (the tactic fails otherwise), which leaves only
   the successor case. *)
Ltac S_rewrite x := destruct x; [ omega | idtac ].


(* Exact point probabilities of [BVrandom_k_mask].  Assume k <= occs_1 mask
   and [missing w mask = O].  Then [w] has probability
   [Unth (pred (comb k (occs_1 mask)))] if it has [k] ones, and 0 otherwise.
   Proof by induction on [mask], then case analysis on the head mask bit and
   the head bit of [w].  For a [true] mask bit there is also a split on
   whether k = S (occs_1 mask). *)
Theorem BVrandom_k_mask_eq :
  forall n (mask : Bvector n),
   forall k, (k <= occs_1 mask)%nat ->
    forall w, missing w mask = O ->
      mu (BVrandom_k_mask k mask) (fun w' => B2U (@BVeq n n w' w)) ==
      (if beq_nat (occs_1 w) k then Unth (pred (comb k (occs_1 mask))) else 0).

Proof.

  intros n mask; induction mask; intros.


  (* mask = Vnil: then k = 0 and w = Vnil. *)
  change (occs_1 Vnil) with O in H.

  replace k with O by omega.

  rewrite (V0_eq _ w).

  simpl; repeat Usimpl; reflexivity.


  (* mask = Vcons a mask: split on the head mask bit. *)
  destruct a.


  (* Mask bit false, handled first: the head of w must be false, and the
     goal reduces to the tail through IHmask. *)
  Focus 2.

  VS_rewrite w.

  simpl BVrandom_k_mask; si.

  destruct w_hd; try discriminate H0.

  change (occs_1 (Vcons false mask)) with (occs_1 mask) in H.

  simpl.

  rewrite IHmask by assumption.

  change (occs_1 (Vcons false w_tl)) with (occs_1 w_tl).

  case (occs_1 w_tl); try (repeat Usimpl; reflexivity).


  (* Mask bit true. *)
  change (occs_1 (Vcons true mask)) with (S (occs_1 mask)) in H.

  VS_rewrite w.

  rewrite missing_true in H0.

  simpl BVrandom_k_mask; si.

  assert (k = S (occs_1 mask) \/ k <= occs_1 mask)%nat as H_ by omega;
    destruct H_ as [Heq|Hle].


  (* k = S (occs_1 mask): the [true] choice has weight 1 (Nmult_ge_Sn_Unth);
     then case on the head bit of w. *)
  rewrite Heq; rewrite Nmult_ge_Sn_Unth by omega; repeat Usimpl.

  rewrite <- pred_Sn.

  destruct w_hd; simpl BVeq.

    rewrite IHmask by omega.

    repeat rewrite comb_n_n.

    reflexivity.

    si.

    assert (occs_1 w_tl <= occs_1 mask)%nat by (apply missing_occs_le; assumption).

    change (occs_1 (Vcons false w_tl)) with (occs_1 w_tl).

    rewrite if_beq_nat_nat_eq_dec.

    case (eq_nat_dec (occs_1 w_tl) (S (occs_1 mask))); intros; [omega_absurd|reflexivity].


  (* k <= occs_1 mask: case on the head bit of w. *)
  destruct w_hd; simpl BVeq.


  (* Head of w is true. *)
  rewrite IHmask by omega; si.   
rewrite Umult_zero_right.
  rewrite Uplus_zero_right.

  assert (k = 0 \/ k > 0)%nat as Hk by omega; destruct Hk as [Hk|Hk].

  rewrite Hk; simpl; repeat Usimpl; reflexivity.

  replace (beq_nat (occs_1 w_tl) (pred k)) with (beq_nat (occs_1 (Vcons true w_tl)) k).

  case (beq_nat (occs_1 (Vcons true w_tl)) k); try (repeat Usimpl; reflexivity).

  rewrite <- Nmult_Umult_assoc_left by (apply Nmult_def_Unth_le; omega).

  change (occs_1 mask) with (pred (S (occs_1 mask))).

  rewrite Umult_Nnth;
    [ | omega | apply comb_le_0_lt; omega ].

  rewrite Nmult_Unth_factor with (m2 := pred (comb k (occs_1 (Vcons true mask)))).

  reflexivity.

  rewrite <- S_pred with (m:=O) by
    (apply comb_le_0_lt; unfold occs_1 in *|-*; simpl; omega).

  change (occs_1 (Vcons true mask)) with (S (occs_1 mask)).

  S_rewrite k.

  change (pred (S k)) with k.

  change (pred (S (occs_1 mask))) with (occs_1 mask).

  aac_rewrite comb_incr_k.

  rewrite <- S_pred with (m:=O).

  aac_rewrite comb_incr_n.

  aac_reflexivity.

  assert ((comb k (occs_1 mask) > 0)%nat) by (apply comb_le_0_lt; omega).

  S_rewrite (comb k (occs_1 mask)).

  simpl; auto with arith.

  destruct k; [ inversion Hk | reflexivity ].


  (* Head of w is false. *)
  rewrite IHmask by omega.

  rewrite mu_fzero_eq; repeat Usimpl.   
change (occs_1 (Vcons false w_tl)) with (occs_1 w_tl).
  case (beq_nat (occs_1 w_tl)); try (repeat Usimpl; reflexivity).

  rewrite Uinv_Nmult.

  rewrite <- Nmult_Umult_assoc_left by (apply Nmult_def_Unth_le; omega).

  change (occs_1 mask) with (pred (S (occs_1 mask))).

  rewrite Umult_Nnth; [ | omega | apply comb_le_0_lt; omega ].

  apply Nmult_Unth_factor with (m2 := pred (comb k (occs_1 (Vcons true mask)))).

  change (occs_1 (Vcons true mask)) with (S (occs_1 mask)).

  rewrite <- pred_Sn.

  assert ((comb k (occs_1 mask) > 0)%nat) by (apply comb_le_0_lt; omega).

  repeat rewrite <- S_pred with (m:=O).

  aac_rewrite comb_incr_n; aac_reflexivity.

  simpl; S_rewrite (comb k (occs_1 mask)); simpl; auto with arith.

  apply comb_le_0_lt.

  change (occs_1 (Vcons true mask)) with (S (occs_1 mask)); omega.


Qed.


(** ** Correction of a vector by adding and removing 1-bits *)


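(** [correct K' ki km] returns a correction of [K'] obtained by clearing [ki] of its 1-bits and setting [km] of its 0-bits, both sets of positions being chosen uniformly at random. *)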
(* [correct K' ki km] clears [ki] of the 1-bits of [K'] and sets [km] of its
   0-bits.  Both sets of positions are drawn with [BVrandom_k_mask], which is
   uniform by [BVrandom_k_mask_eq]. *)
Definition correct n (K' : Bvector n) (ki km : nat) : distr (Bvector n) :=
  Mlet (BVrandom_k_mask ki K') (fun cleared =>
    Mlet (BVrandom_k_mask km (~~ K')) (fun added =>
      Munit ((K' && (~~ cleared)) || added))).


(* The indicator [fun x => B2U (beq_nat (missing x mask) 0)] covers the
   predicate [missing x mask = 0] in the sense of ALEA's [cover].  Presumably
   this means it equals 1 where the predicate holds and 0 elsewhere. *)
Theorem missing_cover : forall n (mask : Bvector n),
 cover (fun x => missing x mask = 0%nat) (fun x => B2U (beq_nat (missing x mask) 0%nat)).
Proof.
  intros n mask x; split; intro Hmiss.
  rewrite Hmiss; reflexivity.
  rewrite <- beq_nat_neq by assumption; reflexivity.
Qed.


(* Assume [missing a (~~ K') = 0] (presumably: every 1-bit of [a] is a 0-bit
   of [K']).  Then [K' || a] equals [K' || K] exactly when [a] equals
   [K && ~~ K'].  Proof by induction on [n], with an exhaustive split on the
   head bits. *)
Lemma or_eq_missing : forall n (K K' a : Bvector n),
  missing a (~~ K') = 0%nat ->
  B2U (BVeq (K' || K) (K' || a)) ==
  B2U (BVeq a (K && (~~ K'))).

Proof.

  induction n; intros;
  [ repeat setoid_rewrite V0_eq | VS_rewrite a; VS_rewrite K; VS_rewrite K' ].

  (* n = 0: all vectors are Vnil. *)
  rewrite (V0_eq _ a); reflexivity.

  (* n = S n: reduce to the tails with IHn, then check all head-bit
     combinations.  Those excluded by H are refuted with missing_S. *)
  simpl.

  simpl Bneg in H.

  assert (B2U (BVeq (K'_tl || K_tl) (K'_tl || a_tl)) ==
          B2U (BVeq a_tl (K_tl && (~~ K'_tl)))).

    apply IHn.

    apply (missing_0_tl _ _ H).

  destruct K_hd; destruct K'_hd; destruct a_hd; simpl; try rewrite H0;
  first [ reflexivity | rewrite missing_S in H; discriminate ].

Qed.


(* Setting [missing K K'] uniformly chosen 0-bits of [K'] yields [K' || K]
   with probability [1/] (comb (missing K K') (n - occs_1 K')).  The proof
   restricts the integrand to the range of [BVrandom_k_mask] (using
   [or_eq_missing]) and then applies [BVrandom_k_mask_eq]. *)
Theorem cancel_missing : forall n (K K' : Bvector n),
  mu
    (Mlet (BVrandom_k_mask (missing K K') (~~ K')) (fun x => Munit (K' || x)))
    (fun v => B2U (BVeq (K' || K) v))
  == [1/] (comb (missing K K') (n - occs_1 K')).

Proof.

  intros.

  rewrite <- Bneg_occs_1.

  rewrite Mlet_simpl.

  setoid_rewrite Munit_simpl.

  (* Inside the range of BVrandom_k_mask, replace the integrand. *)
  rewrite (range_eq (BVrandom_k_mask_range (missing K K') (~~ K')))
    with (g := (fun x : Bvector n => B2U (BVeq x (K && (~~ K')))))
    by apply or_eq_missing.

  rewrite BVrandom_k_mask_eq.

  rewrite <- beq_nat_refl; reflexivity.

  (* Side conditions of BVrandom_k_mask_eq: first the bound on k, then
     missing w mask = 0. *)
  unfold missing, occs_1.

  rewrite (_ : AAC.Commutative eq (BVand (n:=n))); apply occs_1_BVand.

  apply missing_BVand.

Qed.


(* Congruence for functions into [U]: Leibniz-equal arguments have
   [==]-equal images. *)
Proposition Of_equal : forall A (f : A -> U) x y, x = y -> f x == f y.
Proof.
  intros A f x y Hxy.
  subst y.
  reflexivity.
Qed.


(* The product of two boolean indicators is the indicator of their
   conjunction.  It is checked on each of the four boolean cases. *)
Lemma Umult_B2U_andb : forall x y, (B2U x) * (B2U y) == B2U (andb x y).
Proof.
  intros [|] [|]; unfold B2U; simpl; repeat Usimpl; reflexivity.
Qed.


(* Pointwise identity used by [correct_eq] to rewrite its integrand.  Take
   [x] with [missing x K' = 0] and [y] with [missing y (~~ K') = 0] (the two
   indicators on the left).  Then [K] equals [(K' && ~~ x) || y] exactly
   when [x = K' && ~~ K] and [y = K && ~~ K'].  Products of indicators are
   turned into [andb] with [Umult_B2U_andb], and both boolean conditions are
   shown equivalent via [bool_prop_iff]. *)
Lemma correct_success_eq : forall n (K K' : Bvector n),
  forall x y,
      B2U (beq_nat (missing x K') 0) *
      (B2U (beq_nat (missing y (~~ K')) 0) *
       B2U (BVeq K ((K' && ~~ x) || y)))
   == B2U (BVeq x (K' && ~~ K)) *
      B2U (BVeq y (K && ~~ K')).

Proof.

  intros.

  repeat rewrite Umult_B2U_andb.

  apply Of_equal.

  rewrite bool_prop_iff; split; intro.

  (* Left-to-right: from the three tests to the two equalities. *)
  bool_connectors; destruct H as (H1,(H2,H3)).

  symmetry in H2; symmetry in H1.

  assert (missing x K' = 0%nat) by (apply beq_nat_eq; assumption).

  assert (missing y (~~K') = 0%nat) by (apply beq_nat_eq; assumption).

  assert (x && ~~K' = BVfalse n) by (apply occurrences_true_0; assumption).

  assert (y && (~~(~~K')) = BVfalse n) by (apply occurrences_true_0; assumption).

  clear H1 H2 H H0.

  rewrite involutive in H5.

  rewrite BVeq_eq_true in H3.

  split; rewrite BVeq_eq_true.

  (* x = K' && ~~ K *)
  symmetry.

  rewrite H3.

  rewrite demorgan.

  aac_rewrite (BVand_empty K' y); [| aac_rewrite H5; reflexivity ].

  rewrite demorgan; BVsimpl.

  rewrite BVand_or_distrib_r.

  rewrite BVand_neg_false.

  replace (K' && x) with (x && ~~ ~~ K') by (rewrite involutive; aac_reflexivity).

  rewrite (BVand_empty x (~~K')); [ aac_reflexivity | assumption ].

  (* y = K && ~~ K' *)
  symmetry.

  rewrite H3.

  aac_rewrite (BVand_or_distrib_r (~~K') (K'&&~~x) y).

  aac_rewrite (BVand_neg_false K').

  rewrite BVand_false.

  aac_rewrite (BVand_empty y K'); [ reflexivity | assumption ].

  (* Right-to-left.  Local helper: the all-false vector has no 1-bit. *)
  assert
    (forall n (v : Bvector n), v = BVfalse _ -> occurrences v true = 0)%nat
    as BVfalse_occs_0.

    intros.

    setoid_rewrite occurrences_Vconst.

    assumption.

  rewrite andb_true_iff in H; destruct H.

  rewrite BVeq_eq_true in H; rewrite BVeq_eq_true in H0.

  repeat rewrite andb_true_iff; split; [|split];
  try (rewrite beq_nat_true_iff; apply BVfalse_occs_0);
  try rewrite BVeq_eq_true.

  rewrite H.

  aac_rewrite (BVand_neg_false K').

  apply BVand_false.

  rewrite involutive.

  rewrite H0.

  aac_rewrite (BVand_neg_false K').

  apply BVand_false.

  rewrite H.

  BVsimpl.

  rewrite BVand_or_distrib_r.

  rewrite (BVand_neg_false K').

  rewrite H0.

  aac_rewrite <- (BVand_or_distrib_r K K' (~~K')) in_right.

  rewrite BVor_neg_true; aac_reflexivity.

Qed.


(** [correct K' (missing K' K) (missing K K')] recovers [K] from [K'] with probability [ [1/] (comb (missing K' K) (occs_1 K') * comb (missing K K') (n - occs_1 K')) ], where [n - occs_1 K'] is the number of 0-bits of [K']. *)
(* Success probability of [correct].  With exactly [missing K' K] removals
   and [missing K K'] additions, [correct K'] returns [K] with probability
   [1/] (comb (missing K' K) (occs_1 K') * comb (missing K K') (n - occs_1 K')).
   The proof proceeds in four steps:
   - split the reciprocal of the product into a product of reciprocals
     ([Umult_Nnth]);
   - restrict both integrals to the ranges of [BVrandom_k_mask];
   - rewrite the integrand with [correct_success_eq];
   - factor the double integral into two instances of [BVrandom_k_mask_eq]. *)
Theorem correct_eq :
  forall n (K K' : Bvector n),
    mu (correct K' (missing K' K) (missing K K')) (fun x => B2U (BVeq K x)) ==
    [1/] (comb (missing K' K) (occs_1 K') * comb (missing K K') (n - occs_1 K')).

Proof.

  intros; unfold correct; si.

  rewrite <- Umult_Nnth;
    [ | apply comb_le_0_lt; apply missing_le_occs
      | apply comb_le_0_lt; apply missing_le_not_occs ].

  rewrite (range_cover (BVrandom_k_mask_range (missing K' K) K') (missing_cover K')).

  setoid_rewrite
    (range_cover (BVrandom_k_mask_range (missing K K') (Bneg n K'))
                 (missing_cover (Bneg n K'))).

  setoid_rewrite <- mu_stable_mult; unfold fmult.

  setoid_rewrite correct_success_eq.

  setoid_rewrite <- fmult_def; setoid_rewrite mu_stable_mult.

  setoid_rewrite Umult_sym.

  setoid_rewrite <- fmult_def; setoid_rewrite mu_stable_mult.

  apply Umult_eq_compat.

  (* Factor for the bits set among the 0-bits of K' (mask ~~ K'). *)
  rewrite BVrandom_k_mask_eq; [
    | rewrite Bneg_occs_1; apply missing_le_not_occs
    | apply missing_BVand ].

  rewrite <- beq_nat_refl.

  rewrite Bneg_occs_1.

  reflexivity.

  (* Factor for the bits cleared among the 1-bits of K' (mask K'). *)
  rewrite BVrandom_k_mask_eq; [
    | apply missing_le_occs
    | rewrite (_ : AAC.Commutative eq (BVand (n:=n))); apply missing_BVand ].

  rewrite <- beq_nat_refl.

  reflexivity.

Qed.