Theory Separation_Lenses.SLens_Pullback

(* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
   SPDX-License-Identifier: MIT *)

theory SLens_Pullback
  imports
    Shallow_Separation_Logic.Assertion_Language 
    Shallow_Separation_Logic.Weak_Triple
    Shallow_Micro_Rust.Shallow_Micro_Rust
    Shallow_Separation_Logic.Triple
    Shallow_Separation_Logic.Function_Contract 
    SLens
begin

section ‹Pullbacks along Separation lenses›

text‹This is the central theory around separation lenses: We show that contracts and proofs
can be extended / 'pulled back' along a separation lens. While deceptively short, this theory
is very effective in capturing technical boilerplate involved in extending interface
implementations from smaller to larger separation algebras. For example, if a separation algebra
is needed that implements interfaces ‹A› and ‹B›, one can independently construct implementations
‹sA› and ‹sB› of ‹A› and ‹B›, and then pull them back along the lenses projecting ‹sA × sB› onto its
components to establish ‹sA × sB› as a separation algebra implementing ‹both› ‹A› and ‹B›. Writing
the necessary boilerplate by hand is not necessarily difficult, but it is tedious for complex
interfaces or for a large number of interfaces that have to be composed as above.›

―‹Pullback of an assertion along a lens: a state satisfies the pulled-back assertion exactly
 when its lens view satisfies the original assertion. The definition is registered as an
 instance of the overloaded constant pull_back_const, which the lemmas below apply to a lens
 using a postfix notation.›
definition pull_back_assertion :: ('s, 't) lens  't assert  's assert
  where pull_back_assertion l ξ = {σ. lens_view l σ  ξ}
adhoc_overloading pull_back_const  pull_back_assertion

―‹Pulling back along a composed lens is the same as pulling back in two stages: first along
 l1 (whose source is the target of l0), then along l0.›
lemma pull_back_assertion_compose:
  shows pull_back_assertion (l0 L l1) ξ = pull_back_assertion l0 (pull_back_assertion l1 ξ)
  by (clarsimp simp add: pull_back_assertion_def asat_def compose_lens_components)

―‹Pullback of a state relation along a lens. The relation R relates a target state to a
 result value and a new target state. The pulled-back relation relates a source state σ to
 (v, σ') when R relates the view of σ to (v, τ') for some τ', and σ' is obtained by writing
 τ' back into σ with lens_update. Only the viewed part of σ is replaced; preservation of
 the remainder presumably follows from the lens laws.›
definition pull_back_relation :: ('s, 't) lens  ('t  'v × 't  bool)
   ('s  'v × 's  bool) where
  pull_back_relation l R  λσ (v, σ').
      (let τ = lens_view l σ in
          τ'. R τ (v, τ')  lens_update l τ' σ = σ')
adhoc_overloading pull_back_const  pull_back_relation

―‹The context Θ lifts the context Γ along the lens l when the yield handler of Θ is a
 lifting of the yield handler of Γ, in the sense of is_lifted_yield_handler (defined
 elsewhere).›
definition is_lifted_striple_context where
  is_lifted_striple_context l Γ Θ 
      is_lifted_yield_handler l (yh Γ) (yh Θ)

―‹Canonical lifting of contexts: the yield handler of Θ must be exactly the canonical
 pullback of the yield handler of Γ along l. This is the side condition assumed by the
 lemma pull_back_striple.›
definition is_canonical_lifted_striple_context where
  is_canonical_lifted_striple_context l Γ Θ 
      (yh Θ = canonical_pull_back_yield_handler l (yh Γ))

context slens
begin

―‹Pull back a whole striple context along the separation lens. The new context is built from
 the canonical pullback of the original raw yield handler. lift_definition requires showing
 that the resulting raw context is valid (is_valid_striple_context). This is discharged using
 lens validity and canonical_pull_back_yield_handler_log_preserving.›
lift_definition pull_back_striple_context :: ('t, 'abort, 'i, 'o) striple_context  ('s, 'abort, 'i, 'o) striple_context
  is λΓ. make_striple_context_raw (canonical_pull_back_yield_handler l (yield_handler_raw Γ))
   by (simp add: lens_valid canonical_pull_back_yield_handler_log_preserving
     is_valid_striple_context_def)

―‹The yield handler of a pulled-back context is the canonical pullback of the original
 yield handler.›
lemma pull_back_striple_context_yield_handler:
  shows yh (pull_back_striple_context Γ) = canonical_pull_back_yield_handler l (yh Γ)
  by (transfer, simp)

―‹Pulling back the context without yields gives that same context again. Declared as a simp
 rule.›
lemma pull_back_striple_context_no_yield[simp]:
  shows pull_back_striple_context striple_context_no_yield = striple_context_no_yield
  by (transfer, simp add: lens_valid striple_context_raw_no_yield_def)

end

―‹Make the overloaded pullback notation also apply to striple contexts. The locale-qualified
 name is needed because this declaration sits outside the slens context.›
adhoc_overloading pull_back_const  slens.pull_back_striple_context

―‹Pull back a function contract along a lens. The precondition is pulled back, and for each
 result so are the postcondition and the abort condition. The contract is rebuilt with
 make_function_contract_with_abort.›
definition pull_back_contract where
  pull_back_contract l   make_function_contract_with_abort
      (l¯ (function_contract_pre )) 
      (λr. l¯ (function_contract_post  r))
      (λr. l¯ (function_contract_abort  r))
adhoc_overloading pull_back_const  pull_back_contract

―‹Rule collections for pullbacks: rewrite rules and introduction rules. The lemmas below are
 added to them via the attributes slens_pull_back_simps and slens_pull_back_intros.›
named_theorems slens_pull_back_simps
named_theorems slens_pull_back_intros

context slens
begin

text‹Pullback of assertions along separation lenses commutes with the assertion connectives,
in particular with separating conjunction:›

―‹Pullback commutes with indexed unions of assertions.›
lemma pull_back_assert_Union[slens_pull_back_simps]:
  shows l¯ (x. ξ x) = (x. l¯ (ξ x))
  by (auto simp add: pull_back_assertion_def)

―‹The pullback of the trivially true assertion UNIV is UNIV.›
lemma pull_back_assertion_univ[slens_pull_back_simps]:
  shows l¯ UNIV = UNIV
  by (simp add: pull_back_assertion_def)

―‹The pullback of the unsatisfiable (empty) assertion is empty.›
lemma pull_back_assertion_false[slens_pull_back_simps]:
  shows l¯ {} = {}
  by (simp add: pull_back_assertion_def)

―‹Pure assertions (apure) do not inspect the state, so they are unchanged by pullback.›
lemma pull_back_assertion_pure[slens_pull_back_simps]:
  shows l¯ P = P
  by (simp add: apure_def pull_back_assertion_def)

―‹Pullback commutes with separating conjunction.›
lemma pull_back_asepconj[slens_pull_back_simps]:
  fixes ξ τ :: 't assert
  shows l¯ (ξ  τ) = (l¯ ξ)  (l¯ τ)
  (* After unfolding, safe splits the equation into the two inclusions. *)
  apply (clarsimp simp add: asepconj_def pull_back_assertion_def asat_def aentails_def; safe)
  (* One inclusion needs to split a source state whose view decomposes. *)
  apply (metis slens_lift_decomposition slens_valid)
  (* The other uses that the view preserves disjointness and sums. *)
  apply (meson slens_valid slens_view_local1 slens_view_local2)
  done

―‹Pullback commutes with iterated separating conjunction over a multiset. Proved by induction
 on the multiset y.›
lemma pull_back_asepconj_multi[slens_pull_back_simps]:
  shows l¯ (⋆⋆{# ξ x . x  y #}) = ⋆⋆{# l¯ (ξ x) . x  y #}
  by (induction y; clarsimp simp add: slens_pull_back_simps simp add: asepconj_simp)

―‹Special case of pull_back_asepconj with UNIV as the second conjunct. It follows from the
 rewrite rules collected so far.›
corollary pull_back_asepconj_univ[slens_pull_back_simps]:
  fixes ξ τ :: 't assert
  shows l¯ (ξ  UNIV) = (l¯ ξ)  UNIV
  by (simp add: slens_pull_back_simps)

―‹Upward closure (ucincl) is preserved by pullback. The proof works via the alternative
 characterisation ucincl_alt, presumably "unchanged by conjoining UNIV as a frame". That
 characterisation is transported through the pullback by the rewrite rules above.›
lemma pull_back_ucincl[slens_pull_back_intros]:
  assumes ucincl π
  shows ucincl (l¯ π)
proof -
  (* Rewrite right-to-left with the pullback simp rules to move the frame inside the pullback. *)
  have l¯ π = l¯ π  UNIV
    using assms by (clarsimp simp add: ucincl_alt simp flip: slens_pull_back_simps)
  from this show ?thesis using ucincl_alt
    by auto
qed

―‹Pullback both preserves and reflects entailment. Reflection needs every target state to
 be the view of some source state; this surjectivity follows from the lens laws.›
lemma pull_back_aentails[slens_pull_back_simps]:
  shows (l¯ α  l¯ β)  (α  β)
proof -
  (* Writing x into the zero state and viewing it gives back x ... *)
  have x. lens_view l (lens_update l x 0) = x
    by (meson lens_laws lens_valid)
  (* ... hence the view is surjective. *)
  from this have x. y. lens_view l y = x
    by blast
  from this show ?thesis
    by (auto simp add: asat_def aentails_def pull_back_assertion_def) metis
qed

―‹Satisfaction of a pulled-back assertion by σ is satisfaction of the original assertion by
 the view of σ.›
lemma pull_back_asat_adjoint[slens_pull_back_simps]:
  shows (lens_view l) σ  ξ  σ  l¯ ξ
  by (clarsimp simp add: pull_back_assertion_def asat_def)

―‹Pullback commutes with the intersection of two assertions.›
lemma pull_back_assertion_int[slens_pull_back_simps]:
  shows l¯ (φ  ψ) = l¯ φ  l¯ ψ
  by (auto simp add: pull_back_assertion_def asat_def)

―‹The pullback of a contract built with make_function_contract_with_abort is computed
 componentwise on precondition, postcondition and abort condition.›
lemma pull_back_contract_with_abort [slens_pull_back_simps]:
  pull_back_contract l (make_function_contract_with_abort pre post ab) 
      make_function_contract_with_abort
      (l¯ pre) (λr. l¯ (post r)) (λr. l¯ (ab r))
  by (clarsimp simp add: pull_back_contract_def)

―‹The same for contracts without an explicit abort condition. The proof rewrites with
 bot_fun_def and the pullback simp rules. This suggests, but does not show here, that
 make_function_contract uses the bottom abort condition.›
lemma pull_back_contract [slens_pull_back_simps]:
  pull_back_contract l (make_function_contract pre post) 
      make_function_contract
      (l¯ pre) (λr. l¯ (post r))
  by (clarsimp simp add: bot_fun_def slens_pull_back_simps)

―‹Introduction-rule form of pull_back_aentails: entailment is preserved by pullback.›
lemma pull_back_aentailsI[slens_pull_back_intros]:
  assumes α  β
  shows l¯ α  l¯ β
  using assms pull_back_aentails by simp

―‹Elimination-rule form of pull_back_aentails.›
lemmas pull_back_aentailsE = pull_back_aentails[elim_format]

text‹The following is central: Assertion triples can be pulled back along separation lenses:›

―‹Pullback of an assertion triple for a single state transition. Suppose the transition from
 the view of s to τ' satisfies the weak triple for φ and ψ. Then the transition from s to the
 state obtained by writing τ' back into s satisfies the pulled-back triple. The argument is
 frame-based: every upward-closed frame ξ on the source side is transported to the target
 side, the target triple is applied there, and the result is lifted back.›
lemma pull_back_atriple[slens_pull_back_intros]:
  assumes φ  (lens_view l s, τ') weak ψ
  shows l¯ φ  (s,lens_update l τ' s) weak l¯ ψ
  using assms
proof -
  assume φ  (lens_view l s, τ') weak ψ
  let ?p = lens_view l
  (* Fix an upward-closed source frame and split s into a part satisfying the pulled-back
     precondition and a frame part. *)
  { fix ξ assume ucincl ξ and s  l¯ φ  ξ
    then obtain a b where s = a + b and a  b and a  l¯ φ and b  ξ
      using asepconjE by blast
    from this have ?p s = ?p a + ?p b and ?p a  ?p b
      using slens_valid slens_view_local2 slens_view_local1 by blast+
    (* Target-side frame: built from the singleton containing the view of the frame part b,
       and shown to be upward closed below. *)
    let ?pb = ?p b
    let ?ξ' = {?pb}  
    have ucincl ?ξ'
      by (simp add: ucincl_UNIV ucincl_asepconjR)
    moreover have ?pb  ?ξ'
      by (clarsimp simp add: asat_def asepconj_def) force
    moreover from this and ?p s = ?p a + ?p b and ?p a  ?p b a  l¯ φ
      have ?p s  φ  ?ξ'
      by (metis asat_def asepconjI mem_Collect_eq pull_back_assertion_def)
    (* Apply the target triple with the transported frame. *)
    moreover from this and ucincl ?ξ' and φ  (lens_view l s, τ') weak ψ have
      τ'  ψ  ?ξ'
      using atriple_def by blast
    moreover from this obtain a' b' where τ' = a' + b' and a'  b' and a'  ψ
      and b'  ?ξ'
      using asepconjE by blast
    moreover from this obtain b0 where b' = ?pb + b0 and ?pb  b0
      by (metis asat_def asepconjE singletonD)
    (* Lift the decomposition of τ' back into the source state. *)
    moreover from calculation have lens_update l τ' s = lens_update l a' a + lens_update l b' b
      using slens_valid slens_update_general by (metis a  b s = a + b)
    moreover from calculation have lens_update l a' a  lens_update l b' b
      by (metis a  b disjoint_sym sepalg_apart_plus2 slens_lens_laws(1) slens_update_local4)
    moreover from calculation have lens_update l a' a  l¯ ψ
      by (metis lens_laws_update(1) local.lens_valid pull_back_asat_adjoint)
    (* The updated frame part only grows b, so it stays in the upward-closed frame ξ. *)
    moreover from calculation have lens_update l b' b = b + lens_update l b0 0
      by (metis slens_lens_laws(2) slens_update_disjoint)
    moreover from this have lens_update l b' b  ξ
      by (metis b  ξ ucincl ξ asat_weaken calculation(10) slens_complement_disj)
    ultimately have lens_update l τ' s  l¯ ψ  ξ
      by (metis asat_def asepconjI) }
  from this show l¯ φ  (s,lens_update l τ' s) weak l¯ ψ
    by (meson asat_def atripleI)
qed

―‹Pullback of a weak striple along the lens. It requires that the source context Θ uses the
 canonical pullback of the yield handler of Γ. Each of the three cases (value, return, abort)
 relates evaluation of the pulled-back expression to evaluation of the original expression,
 via the corresponding expression_pull_back_eval_*_canonical lemma. The result then follows
 from pull_back_atriple.›
lemma pull_back_striple[slens_pull_back_intros]:
  fixes e :: ('t, 'v, 'r, 'abort, 'i prompt, 'o prompt_output) expression
  assumes T: Γ; φ  e  weak ψ  ξ  θ
      and is_canonical_lifted_striple_context l Γ Θ
    shows Θ; l¯ φ  l¯ e weak (λv. l¯ (ψ v))  (λr. l¯ (ξ r))  (λr. l¯ (θ r))
proof -
  from is_canonical_lifted_striple_context l Γ Θ have
    YH: yh Θ = canonical_pull_back_yield_handler l (yh Γ)
    unfolding is_canonical_lifted_striple_context_def by simp
  note assms = YH T
  show ?thesis
proof (intro stripleI)
  (* Case: evaluation terminates with a value. *)
  fix v s s'
  assume s v yield_handler Θ,l¯ e (v,s')
  from this obtain τ' where
     s' = lens_update l τ' s and lens_view l s v yh Γ, e (v, τ')
    using YH expression_pull_back_eval_value_canonical[OF lens_valid] by metis
  from this and Γ; φ  e  weak ψ  ξ  θ obtain φ  (lens_view l s,τ') weak ψ v
    by (meson stripleE_value)
  from this and s' = lens_update l τ' s and pull_back_atriple
    show l¯ φ  (s,s') weak l¯ (ψ v)
    by simp
next
  (* Case: evaluation terminates with a return. *)
  fix r s s'
  assume s r yield_handler Θ,l¯ e (r,s')
  from this obtain τ' where
     s' = lens_update l τ' s and lens_view l s r yh Γ, e (r, τ')
    using YH expression_pull_back_eval_return_canonical[OF lens_valid] by metis
  from this and Γ; φ  e  weak ψ  ξ  θ obtain φ  (lens_view l s,τ') weak ξ r
    by (meson stripleE_return)
  from this and s' = lens_update l τ' s and pull_back_atriple
    show l¯ φ  (s,s') weak l¯ (ξ r)
    by simp
next
  (* Case: evaluation aborts. *)
  fix a s s'
  assume s a yield_handler Θ,l¯ e (a,s')
  from this obtain τ' where
     s' = lens_update l τ' s and lens_view l s a yh Γ, e (a, τ')
    using YH expression_pull_back_eval_abort_canonical[OF lens_valid] by metis
  from this and Γ; φ  e  weak ψ  ξ  θ obtain φ  (lens_view l s,τ') weak θ a
    by (meson stripleE_abort)
  from this and s' = lens_update l τ' s and pull_back_atriple
    show l¯ φ  (s,s') weak l¯ (θ a)
      by simp
  qed
qed

―‹Local abbreviations for the separation-lens operations (embedding, view and the two
 projections). They are used in the locality proofs below and removed again with
 no_notation before the end of the context.›
notation slens_embed ("ι")
notation slens_view ("π")
notation slens_proj0 ("ρ0")
notation slens_proj1 ("ρ1")

―‹Disjointness part of locality for pulled-back relations. Let R be local with respect to φ,
 and let σ_0 satisfy the pulled-back φ. Then a pulled-back step from σ_0 keeps the result
 disjoint from any frame σ_1 that was disjoint from σ_0. This is stated separately because
 it is needed twice in pull_back_local_relation.›
lemma pull_back_local_relation_disj:
  assumes is_local R φ
  shows σ_0 σ_1 σ_0' v.
      σ_0  σ_1  σ_0  l¯ φ  l¯ R σ_0 (v, σ_0')  σ_0'  σ_1
proof -
  fix σ_0 σ_1 σ_0' v
  assume σ_0  σ_1 and σ_0  l¯ φ  and l¯ R σ_0 (v, σ_0')
  (* Express the new state as the ρ1-projection of σ_0 plus the embedding of the new view. *)
  moreover from l¯ R σ_0 (v, σ_0') obtain τ' where
     R (π σ_0) (v, τ') and σ_0' = ρ1 σ_0 + ι τ'
    apply (clarsimp simp add: pull_back_relation_def)
    using slens_update_alt(1) slens_valid by blast
  (* Transport disjointness and the precondition to the views. *)
  moreover from σ_0  σ_1 and σ_0  l¯ φ have
    π σ_0  π σ_1 and π σ_0  φ
    apply (simp add: slens_view_local1)
    apply (metis σ_0  l¯ φ asat_def mem_Collect_eq pull_back_assertion_def)
    done
  (* Locality of R on the target side keeps the new view disjoint from the frame's view. *)
  moreover from this and is_local R φ have τ'  π σ_1
    by (meson calculation(4) is_localE)
  ultimately show σ_0'  σ_1
    by (metis slens_update_alt(1) slens_update_local4 slens_valid)
qed

―‹Locality of relations is preserved by pullback. is_localI produces three subgoals:
 - disjointness, via pull_back_local_relation_disj;
 - a step on σ_0 + σ_1 factors as a step on σ_0 with σ_1 as an untouched frame;
 - conversely, a step on σ_0 extended by σ_1 is a step on σ_0 + σ_1.›
lemma pull_back_local_relation[slens_pull_back_intros]:
  assumes is_local R φ
  shows is_local (l¯ R) (l¯ φ)
proof (intro is_localI)
  (* Disjointness of the result from the frame. *)
  fix σ_0 σ_1 σ_0' v
  assume σ_0  σ_1 and σ_0  l¯ φ  and l¯ R σ_0 (v, σ_0')
  from this show σ_0'  σ_1
    by (metis assms pull_back_local_relation_disj)
next
  (* A step on the combined state factors through a step on σ_0 alone. *)
  fix σ_0 σ_1 σ' v
  assume σ_0  σ_1 and σ_0  l¯ φ and l¯ R (σ_0 + σ_1) (v, σ')
  moreover from this obtain τ' where
    R (π (σ_0 + σ_1)) (v, τ') and lens_update l τ' (σ_0 + σ_1) = σ'
    unfolding pull_back_relation_def by auto
  moreover from this have σ' = ρ1 σ_0 + ρ1 σ_1 + ι τ'
    using slens_valid by (metis calculation(1) slens_complement_additive(2) slens_update_alt(1))
  moreover from R (π (σ_0 + σ_1)) (v, τ') have R (π σ_0 + π σ_1) (v, τ')
    by (simp add: calculation(1) slens_view_local2)
  (* Use locality of R on the target side to split off the frame's view. *)
  moreover from calculation obtain σ_0'' where
    R (π σ_0) (v, σ_0'') and τ' = σ_0'' + π σ_1
    using slens_valid is_local R φ unfolding is_local_def
    by (meson is_valid_slensE pull_back_asat_adjoint)
  moreover from calculation have σ_0''  π σ_1
    using slens_valid by (meson assms is_localE pull_back_asat_adjoint slens_view_local1)
  moreover note facts = calculation
  (* Candidate result on σ_0 alone: keep its ρ1-projection and embed the new view. *)
  let ?σ_0' = ρ1 σ_0 + ι σ_0''
  from facts have ρ1 σ' = ρ1 ?σ_0' + ρ1 σ_1
    using slens_valid by (metis slens_lens_laws(1) slens_complement_additive(2) slens_complement_cancel_core)
  moreover from facts have σ' = ?σ_0' + σ_1
    using slens_valid by (metis slens_update_alt(1) slens_update_local3)
  moreover have π ?σ_0' = σ_0''
    using slens_valid  by (metis slens_lens_laws(1) slens_update_alt(1))
  moreover from calculation have ?σ_0' = lens_update l σ_0'' σ_0
    using slens_valid  by (metis slens_update_alt(1))
  moreover from calculation facts have l¯ R σ_0 (v, ?σ_0')
    unfolding pull_back_relation_def by force
  ultimately show σ_0'. l¯ R σ_0 (v, σ_0')  σ' = σ_0' + σ_1
    by blast
next
  (* A step on σ_0, extended by the frame σ_1, is a step on the combined state. *)
  fix σ_0 σ_1 σ_0' v σ'
  assume σ_0  σ_1
     and σ_0  l¯ φ
     and l¯ R σ_0 (v, σ_0')  σ' = σ_0' + σ_1
  moreover from this have σ_0'  σ_1
    by (meson assms pull_back_local_relation_disj)
  moreover from calculation obtain τ' where R (π σ_0) (v, τ') and σ_0' = lens_update l τ' σ_0
    unfolding pull_back_relation_def by auto
  moreover have σ_0' = ρ1 σ_0 + ι τ'
    using slens_valid by (metis calculation(6) slens_update_alt(1))
  moreover from calculation have π σ' = π σ_0' + π σ_1
    using slens_valid  slens_view_local2 by blast
  (* Locality of R on the target side extends the step by the frame's view. *)
  moreover from calculation and is_local R φ have R (π σ_0 + π σ_1) (v, τ' + π σ_1)
    using slens_valid by (metis asat_def is_localE mem_Collect_eq pull_back_assertion_def slens_view_local1)
  moreover from calculation have σ' = lens_update l (τ' + π σ_1) (σ_0 + σ_1)
    using slens_valid by (metis slens_lens_laws(1) slens_update_local3 slens_view_local1)
  moreover have π (σ_0 + σ_1) = π σ_0 + π σ_1
    by (simp add: calculation(1) slens_view_local2)
  from this and σ' = lens_update l (τ' + π σ_1) (σ_0 + σ_1) and
    R (π σ_0 + π σ_1) (v, τ' + π σ_1) show l¯ R (σ_0 + σ_1) (v, σ')
    unfolding pull_back_relation_def by auto
qed

―‹urust_is_local is preserved when its three arguments y, e and φ are pulled back; y is
 presumably a yield handler. The proof rewrites evaluation of the pulled-back expression with
 the expression_pull_back_eval_*_canonical lemmas and then applies pull_back_local_relation.›
lemma pull_back_local_urust:
  assumes urust_is_local y e φ
    shows urust_is_local (l¯ y) (l¯ e) (l¯ φ)
  using assms
  by (clarsimp simp add: pull_back_relation_def
    expression_pull_back_eval_value_canonical[OF lens_valid]
    expression_pull_back_eval_abort_canonical[OF lens_valid]
    expression_pull_back_eval_return_canonical[OF lens_valid]
    elim!: pull_back_local_relation[elim_format])

―‹Pullback of the (non-weak) sstriple judgement into the pulled-back context. Via
 sstriple_striple', it reduces to the weak version pull_back_striple plus preservation of
 locality (pull_back_local_urust).›
lemma pull_back_sstriple[slens_pull_back_intros]:
  fixes e :: ('t, 'v, 'r, 'abort, 'i prompt, 'o prompt_output) expression
  assumes T: Γ; φ  e  ψ  ξ  θ
    shows l¯ Γ; l¯ φ  l¯ e  (λv. l¯ (ψ v))  (λr. l¯ (ξ r))  (λr. l¯ (θ r))
  using assms by (clarsimp simp add: sstriple_striple' is_canonical_lifted_striple_context_def
     pull_back_local_urust pull_back_striple pull_back_striple_context_yield_handler)

―‹Variant of pull_back_sstriple for a bottom abort condition, which stays bottom after
 pullback. This is an artifact of the current use of ‹⊥› as the abort-postcondition in
 function contracts. The general lemma does not match syntactically, because bottom first
 has to be rewritten as the pullback of bottom. Once we generalize, this should no longer
 be necessary.›
lemma pull_back_sstriple_bot[slens_pull_back_intros]:
  fixes e :: ('t, 'v, 'r, 'abort, 'i prompt, 'o prompt_output) expression
  assumes T: Γ; φ  e  ψ  ξ  
  shows l¯ Γ; l¯ φ  l¯ e  (λv. l¯ (ψ v))  (λr. l¯ (ξ r))  
proof -
  (* The bottom abort condition on the source side is the pullback of the bottom one. *)
  let ?bot = (λ_. ) :: 'abort abort  't assert
  have eq:  = (λr. (l¯ (?bot r)))
    by (simp add: bot_fun_def pull_back_assertion_false)
  from this assms pull_back_sstriple show ?thesis
    by fastforce
qed

―‹Version for all contexts: if the triple holds in every context, so does its pullback. It
 suffices to pull back the instance for the context without yields, whose pullback is itself
 (pull_back_striple_context_no_yield). The result is then generalized with
 sstriple_yield_handler_no_yield_implies_all.›
lemma pull_back_sstriple_universal_bot[slens_pull_back_intros]:
  assumes Γ. Γ; φ  e  ψ  ξ  
  shows Γ. Γ; l¯ φ  l¯ e  (λv. l¯ (ψ v))  (λr. l¯ (ξ r))  
proof -
  from assms have striple_context_no_yield; φ  e  ψ  ξ  
    by simp
  from this have l¯ striple_context_no_yield; l¯ φ  l¯ e  (λv. l¯ (ψ v))  (λr. l¯ (ξ r))  
    by (intro pull_back_sstriple_bot; assumption)
  (* The pulled-back no-yield context is the no-yield context itself. *)
  from this have striple_context_no_yield; l¯ φ  l¯ e  (λv. l¯ (ψ v))  (λr. l¯ (ξ r))  
    using lens_valid by force
  from this show Γ. Γ ; l¯ φ  l¯ e  (λv. l¯ (ψ v))  (λr. l¯ (ξ r))  
    using sstriple_yield_handler_no_yield_implies_all by blast
qed

―‹Pullback of function specifications: if f satisfies a contract in context Γ, then the
 pulled-back function satisfies the pulled-back contract in the pulled-back context.›
lemma pull_back_spec[slens_pull_back_intros]:
  assumes Γ ; f F 
  shows l¯ Γ; l¯ f F l¯ 
  using assms unfolding satisfies_function_contract_def pull_back_contract_def
  by (clarsimp simp add: pull_back_ucincl function_pull_back_def pull_back_sstriple)

―‹Version of pull_back_spec for all contexts. It is restricted to contracts whose abort
 condition is bottom, since it goes through pull_back_sstriple_universal_bot.›
lemma pull_back_spec_universal:
  assumes Γ. Γ; f F 
      and function_contract_abort  = 
  shows Γ. Γ; l¯ f F l¯ 
  using assms unfolding satisfies_function_contract_def pull_back_contract_def
  by (simp add: function_pull_back_def pull_back_sstriple_universal_bot pull_back_ucincl
    slens_pull_back_simps bot_fun_def[symmetric] ucincl_intros)

―‹Remove the local separation-lens abbreviations again.›
no_notation slens_embed ("ι")
no_notation slens_view ("π")
no_notation slens_proj0 ("ρ0")
no_notation slens_proj1 ("ρ1")

―‹Give the rule collections fixed names so that all simplification and introduction rules
 can be referred to even outside of the locale.›
lemmas slens_pull_back_simps_copy = slens_pull_back_simps
lemmas slens_pull_back_intros_copy = slens_pull_back_intros

end

―‹Simplifying to get rid of the wrapper ‹lens l ≡ is_valid_slens l› generated by the locale.
 These are the exported versions of the rule collections for use outside the slens context.›
lemmas slens_pull_back_simps_generic = slens.slens_pull_back_simps_copy[simplified]
lemmas slens_pull_back_intros_generic = slens.slens_pull_back_intros_copy[simplified]

end