Theory Misc.Vector

(* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
   SPDX-License-Identifier: MIT *)

(*<*)
theory Vector
  imports "HOL-Library.Word" Data_Structures.BraunTreeAdditional ListAdditional
begin
(*>*)

section‹A bounded-length vector type›

text‹Here we develop a small utility library of bounded-length vectors.

We use Braun trees to represent the vectors. This gives reasonably
efficient, logarithmic-time implementations of the core operations:
indexing, update, push and pop.›

subsection‹Supporting lemmas›

(* For a Braun tree a, the element at 1-based position Suc n, with n < m and m
   bounded by size a, occurs among the first m elements of list_fast a.
   NOTE(review): the metis call references facts by position
   (basic_trans_rules(22), semiring_norm(174)); such indices can shift between
   Isabelle releases, so this proof may need repair when upgrading. *)
lemma lookup1_take_list_fast:
  assumes braun a
      and m  size a
      and n < m
    shows lookup1 a (Suc n)  set (take m (list_fast a))
using assms  by (metis (mono_tags, lifting) basic_trans_rules(22) in_set_conv_nth length_take
  list_fast_correct min_less_iff_conj nth_list_lookup1 nth_take semiring_norm(174) size_list)

(* If index i is valid for ls and also below n, then the image under f of the
   i-th element belongs to the image of the first n elements of ls. *)
lemma set_image_take_nth:
  assumes i < length ls
      and i < n
    shows f (ls ! i)  f ` set (take n ls)
using assms by (force intro: imageI simp add: in_set_conv_nth)

(* Elimination form of set_image_take_nth. The first assumption claims that
   f (ls ! i) lies outside the image of the first n elements, which contradicts
   set_image_take_nth when i is in range. Hence any conclusion R follows. *)
lemma set_image_take_nth_elim:
  assumes f (ls ! i)  f ` set (take n ls)
      and i < length ls i < n
    shows R
using assms set_image_take_nth by force

(* For a non-empty Braun tree (size Suc n), the list view splits into its
   butlast followed by the element at 1-based position size a.
   The proof first identifies the last list element with index n, then uses
   nth_list_lookup1 to turn that index into a lookup1 call. *)
lemma braun_nonempty_butlast:
  assumes braun a size a = Suc n
  shows butlast (braun_list a) @ [lookup1 a (size a)] = braun_list a 
proof -
  have last (braun_list a) = braun_list a ! n
    using assms by (metis last_conv_nth list_Nil_iff size_list Zero_not_Suc diff_Suc_1 eq_size_0)
  with nth_list_lookup1[of a n] assms show ?thesis
    by (auto simp: snoc_eq_iff_butlast list_Nil_iff)
qed

subsection‹The type itself›

(* Representation invariant for vectors of capacity n.
   - The first component is a Braun tree.
   - The second component equals the tree's size; it is a cached length.
   - That length does not exceed the capacity n. *)
definition is_vector :: 'a tree × nat  nat  bool where
  is_vector x n  braun (fst x)  size (fst x) = snd x  snd x  n

(* Vectors with elements of type 'a and capacity LENGTH('b): (tree, length)
   pairs satisfying is_vector. The (overloaded) flag is needed because the
   carrier set mentions the overloaded constant LENGTH('b). The carrier is
   non-empty because it contains (Leaf, 0). *)
typedef (overloaded) ('a, 'b::{len}) vector =
  { x::('a tree × nat). is_vector x LENGTH('b) }
proof -
  have (Leaf, 0)  {x. is_vector x LENGTH('b)}
    by (clarsimp simp add: is_vector_def)
  from this show ?thesis
    by blast
qed

(*<*)
(* Register the typedef with the Lifting package, which enables the
   lift_definition and transfer calls below. *)
setup_lifting vector.type_definition_vector

text‹The following instance tells the code generator how to decide equality at type ‹vector›:
it simply uses equality on the underlying representation, i.e. on (Braun tree, length) pairs.›
(* Equality on vectors is lifted from equality on the representation pair.
   The instance proof reduces it to injectivity of Rep_vector. *)
instantiation vector :: (equal, len) equal
begin

lift_definition equal_vector :: ('a, 'b::len) vector  ('a, 'b) vector  bool
  is (=) .

instance
  by standard (simp add: equal_vector_def Rep_vector_inject)

end
(*>*)

subsection‹Lifted operations on vectors›

(* Number of elements: the cached second component of the representation. *)
lift_definition vector_len :: ('a, 'b::{len}) vector  nat is snd .

(* Apply f to every element via tree_map, keeping the recorded length.
   The invariant is re-established by auto after unfolding is_vector_def. *)
lift_definition vector_map :: ('a  'c)  ('a, 'b::{len}) vector  ('c, 'b) vector is
  λf (t,n). (tree_map f t, n)
by (auto simp add: is_vector_def split: prod.splits)

(* Element at 0-based index i. lookup1 is 1-based, hence i+1.
   There is no bounds check: for i at or beyond vector_len the result is
   whatever lookup1 yields, so lemmas about it assume i < vector_len. *)
lift_definition vector_nth :: ('a, 'b::{len}) vector  nat  'a is λ(t, n) i. lookup1 t (i+1) .

(* Replace the element at 0-based index i with x (update1 is 1-based).
   For an out-of-range index the vector is returned unchanged. *)
lift_definition vector_update :: ('a, 'b::{len}) vector  nat  'a  ('a, 'b) vector is
  λ(t, n) i x. if i < n then (update1 (Suc i) x t, n) else (t, n)
by (auto simp add: is_vector_def braun_update1 size_update1)

(* The empty vector, represented as (Leaf, 0). *)
lift_definition vector_new :: ('a, 'b::{len}) vector is
  (Leaf, 0)
by (simp add: is_vector_def)

(* Unchecked push: when below capacity, store v at 1-based position Suc n and
   increment the length. At capacity the vector is returned unchanged, so the
   overflow is silent; vector_push reports it instead. *)
lift_definition vector_push_raw :: 'a  ('a, 'b::{len}) vector  ('a, 'b) vector is
  λv (t,n). if n < LENGTH('b) then (update1 (Suc n) v t, Suc n) else (t,n)
by (auto simp add: is_vector_def braun_add_hi size_add_hi)

(* Checked push. Returns Some of the extended vector when there is room, and
   None when the vector already holds LENGTH('b) elements. *)
definition vector_push :: 'a  ('a, 'b::{len}) vector  ('a,'b) vector option where
  vector_push  λv xs. if vector_len xs < LENGTH('b) then Some (vector_push_raw v xs) else None

(* Last element, i.e. lookup1 at 1-based position n.
   NOTE(review): no guard for the empty vector, where this is lookup1 t 0;
   vector_pop only calls it after checking that the length is positive. *)
lift_definition vector_last :: ('a, 'b::{len}) vector  'a is λ(t,n). lookup1 t n .

(* Unchecked pop: remove the last element with del_hi and decrement the
   length. The empty vector is returned unchanged. *)
lift_definition vector_pop_raw :: ('a,'b::{len}) vector  ('a, 'b) vector is
  λ(t, n). if n > 0 then (del_hi n t, n-1) else (t,n)
by (auto split: prod.splits simp add: is_vector_def braun_del_hi list_del_hi simp flip: size_list)

(* Checked pop. For a non-empty vector it returns Some of the last element
   paired with the shortened vector; for an empty vector it returns None. *)
definition vector_pop :: ('a, 'b::{len}) vector  ('a × ('a, 'b) vector) option where
  vector_pop xs  if vector_len xs > 0 then Some (vector_last xs, vector_pop_raw xs) else None

(* List view of the elements, computed by list_fast on the underlying tree.
   The cached length is not needed because the invariant ties it to the tree. *)
lift_definition vector_to_list :: ('a, 'b::{len}) vector  'a list is λ(t, n). list_fast t .

(* Build a vector from a list using brauns1, caching the size computed by
   size_fast. NOTE: if the list is longer than the capacity LENGTH('b), the
   result is the empty vector (Leaf, 0), not a truncated prefix. The lemmas
   below therefore assume length l is within capacity. *)
lift_definition vector_of_list :: 'a list  ('a, 'b::{len}) vector is
  λl. let v = brauns1 l;
           n = size_fast v
        in if n  LENGTH('b) then (v, n) else (Leaf, 0)
by (clarsimp simp add:is_vector_def  Let_def) (metis brauns1_correct size_fast)

(* Extensionality: two vectors are equal when they have the same length and
   agree at every in-range index. Via transfer this reduces to braun_tree_ext. *)
lemma vector_extI:
    fixes xs ys :: ('a, 'l::{len}) vector
  assumes vector_len xs = vector_len ys
      and i. i < vector_len xs  vector_nth xs i = vector_nth ys i
    shows xs = ys
using assms by transfer (auto split: prod.splits simp add: is_vector_def braun_tree_ext)

(* Round trip: for a list that fits within the capacity, converting it to a
   vector and back yields the original list. *)
lemma list_of_vector_of_list:
    fixes v :: ('a, 'b::{len}) vector
  assumes v = vector_of_list l
      and length l  LENGTH('b)
    shows vector_to_list v = l
using assms
  apply transfer
  apply (metis (mono_tags, lifting) brauns1_correct case_prod_conv list_fast_correct size_fast size_list)
  done

(* The empty vector has the empty list view. *)
lemma vector_new_list [simp]:
  shows vector_to_list vector_new = []
by transfer (simp add: list_fast_correct)

(* The list view has exactly vector_len elements. *)
lemma vector_len_list [simp]:
  shows length (vector_to_list xs) = vector_len xs
by transfer (auto split: prod.splits simp add: list_fast_correct is_vector_def size_list)

(* For in-range indices, indexing the list view agrees with vector_nth.
   As a simp rule it rewrites list indexing into vector_nth. *)
lemma vector_nth_list [simp]:
  assumes i < vector_len xs
    shows vector_to_list xs ! i = vector_nth xs i
using assms by transfer (auto simp add: is_vector_def list_fast_correct nth_list_lookup1)

(* A successful push appends x to the end of the list view. *)
lemma vector_push_list:
    fixes l :: ('a, 'l::{len}) vector
  assumes vector_push x l = Some l'
    shows vector_to_list l' = vector_to_list l @ [x]
using assms unfolding vector_push_def by transfer (auto split: if_splits simp add: is_vector_def
  list_fast_correct list_add_hi)

(* A successful pop splits the list view: the original list equals the list
   of the remaining vector followed by the popped element x. *)
lemma vector_pop_list:
  fixes l :: ('a, 'l::{len}) vector
  assumes vector_pop l = Some (x, l')
  shows vector_to_list l = vector_to_list l' @ [x]
  using assms
  unfolding vector_pop_def
  apply transfer
  apply (auto simp add:  is_vector_def list_fast_correct list_del_hi dest: braun_nonempty_butlast split: nat_diff_split_asm prod.splits)
  done

(* A vector's length never exceeds its capacity LENGTH('l).
   This follows directly from the is_vector invariant. *)
lemma vector_len_bound [simp]:
  fixes xs :: ('a, 'l::{len}) vector
  shows vector_len xs  LENGTH('l)
by transfer (simp add: is_vector_def)

(* For a list that fits within the capacity, vector_of_list preserves its
   length. *)
lemma vector_len_vector_of_list:
  assumes length xs  LENGTH('l::{len})
    shows vector_len ((vector_of_list xs)::('a, 'l) vector) = length xs
  by (metis assms list_of_vector_of_list vector_len_list)

(* The empty vector has length 0. *)
lemma vector_new_len [simp]:
  shows vector_len vector_new = 0
by transfer simp

(* A successful push increases the length by one. *)
lemma vector_push_len [simp]:
    fixes xs :: ('a, 'l::{len}) vector
  assumes vector_push a xs = Some xs'
    shows vector_len xs' = vector_len xs + 1
using assms unfolding vector_push_def by transfer (auto split: if_splits)

(* After a successful pop, the original length is one more than the length of
   the remaining vector. *)
lemma vector_pop_len [simp]:
    fixes xs :: ('a, 'l::{len}) vector
  assumes vector_pop xs = Some (a, xs')
    shows vector_len xs = vector_len xs' + 1
using assms unfolding vector_pop_def by transfer (auto split: if_splits)

(* vector_update never changes the length, including for out-of-range
   indices. *)
lemma vector_update_len [simp]:
  fixes xs :: ('a, 'l::{len}) vector
  shows vector_len (vector_update xs i v) = vector_len xs
by (transfer, auto)

(* The list view is non-empty exactly when the vector has positive length. *)
lemma vector_to_list_Nil_vector_len:
  shows vector_to_list vs  []  0 < vector_len vs
  by (metis length_greater_0_conv vector_len_list)

(* Indexing a vector built from a list that fits within the capacity agrees
   with indexing that list. *)
lemma vector_nth_vector_of_list:
  assumes i < length xs
      and length xs  LENGTH('l::{len})
    shows vector_nth ((vector_of_list xs::('a, 'l) vector)) i = xs ! i
  using assms
  by (metis list_of_vector_of_list vector_len_list vector_nth_list)

(* vector_map acts pointwise on in-range indices. *)
lemma vector_map_nth [simp]:
    fixes xs :: ('a, 'l::{len}) vector
  assumes i < vector_len xs
    shows vector_nth (vector_map f xs) i = f (vector_nth xs i)
using assms by transfer (auto split: prod.splits simp add: is_vector_def tree_map_lookup1)

(* vector_map preserves length. *)
lemma vector_map_len[simp]:
  shows vector_len (vector_map f v) = vector_len v
  by transfer auto

(* vector_map distributes over function composition, stated both point-free
   and applied to a vector. Both parts are proved by extensionality. *)
lemma vector_map_comp:
  shows vector_map f  vector_map g = vector_map (f  g)
    and vector_map f (vector_map g v) = vector_map (f  g) v
  by (intro ext vector_extI; simp)+

(* Mapping the identity function is the identity, stated both point-free and
   applied to a vector. *)
lemma vector_map_id[simp]:
  shows vector_map id = id
    and vector_map (λx. x) v = v
  by (intro ext vector_extI; simp)+

(* Reading an in-range index j after an in-range update at i yields the new
   value when i = j and the old value otherwise. *)
lemma vector_nth_update [simp]:
    fixes l :: ('a, 'l::{len}) vector
  assumes i < vector_len l
      and j < vector_len l
    shows vector_nth (vector_update l i v) j = (if i = j then v else vector_nth l j)
using assms by transfer (auto simp add: lookup1_update1 is_vector_def)


(* After a successful push of x:
   - indices below the old length still hold their old elements;
   - index vector_len l holds x.
   The two local facts are stated on raw trees a:
   - writing at 1-based position Suc (size a) leaves earlier positions alone;
   - the new position Suc (size a) holds x.
   transfer then lifts both cases to vectors. *)
lemma vector_push_nth:
    fixes l :: ('a, 'l::{len}) vector
  assumes i  vector_len l
      and vector_push x l = Some l'
    shows vector_nth l' i = (if i < vector_len l then vector_nth l i else x)
proof -
  have lookup1 (update1 (Suc (size a)) x a) (Suc i) = lookup1 a (Suc i)
    if braun a and braun (update1 (Suc (size a)) x a) and i < size a
    for a 
    using that by (metis Suc_eq_plus1 less_SucI list_add_hi nth_append nth_list_lookup1 size_add_hi size_list)
  moreover have lookup1 (update1 (Suc (size a)) x a) (Suc (size a)) = x
    if braun a braun (update1 (Suc (size a)) x a)
      and size (update1 (Suc (size a)) x a) = Suc (size a)
    for a
    using that braun_nonempty_butlast list_add_hi by fastforce
  ultimately show ?thesis
    using assms unfolding vector_push_def
    by transfer (auto simp add: is_vector_def split: if_splits)
qed

(* After a successful push, the pushed element sits at the old length index.
   This is vector_push_nth instantiated with i = vector_len l. *)
lemma vector_push_end [simp]:
    fixes l :: ('a, 'l::{len}) vector
  assumes vector_push x l = Some l'
    shows vector_nth l' (vector_len l) = x
using assms vector_push_nth by force

(* vector_push fails (returns None) exactly when the vector is full. *)
lemma vector_push_nil [simp]:
  fixes l :: ('a, 'l::{len}) vector
  shows vector_push x l = None  vector_len l = LENGTH('l)
unfolding vector_push_def by transfer (auto split: prod.splits simp add: is_vector_def)

(* vector_pop fails (returns None) exactly when the vector is empty. *)
lemma vector_pop_nil [simp]:
  fixes l :: ('a, 'l::{len}) vector
  shows vector_pop l = None  vector_len l = 0
unfolding vector_pop_def by transfer (auto split: if_splits simp add: is_vector_def)

(* Reading an in-range index of the original vector after a successful pop:
   indices below the new length come from the remaining vector, and the last
   index holds the popped element x. The list-view simp rules are disabled
   here and used in reverse (simp flip), reducing the goal to nth_append on
   vector_pop_list. *)
―‹NOTE: This lemma is not used at present, but seems worth keeping.›
lemma vector_pop_nth:
    notes vector_nth_list [simp del]
      and vector_len_list [simp del]
    fixes l :: ('a, 'l::{len}) vector
  assumes i < vector_len l
      and vector_pop l = Some (x, l')
    shows vector_nth l i = (if i < vector_len l' then vector_nth l' i else x)
using assms vector_pop_list[of l x l'] by (auto simp add: nth_append simp flip: vector_nth_list
  vector_len_list)

(* Two updates at the same in-range index: the later write wins. *)
lemma vector_update_overwrite [simp]:
  assumes i < vector_len ls
    shows vector_update (vector_update ls i v) i w = vector_update ls i w
using assms by (intro vector_extI) auto

(* Updates at two distinct in-range indices commute. *)
―‹NOTE: This lemma is not used at present, but seems worth keeping.›
lemma vector_update_swap:
  assumes i  j
      and i < vector_len ls
      and j < vector_len ls
    shows vector_update (vector_update ls i v) j w = vector_update (vector_update ls j w) i v
using assms by (intro vector_extI) auto

(* Bounds-checked indexing: Some element for an in-range index n, otherwise
   None. Note the argument order: the index comes first and the vector
   second, which is the reverse of vector_nth. *)
definition vector_nth_opt :: nat  ('a, 'l::{len}) vector  'a option where
  vector_nth_opt n v  if n < vector_len v then Some (vector_nth v n) else None

(* For an in-range index, vector_nth_opt returns Some of vector_nth. *)
lemma vector_nth_opt_spec:
    fixes xs :: ('a, 'l::{len}) vector
  assumes n < vector_len xs
    shows vector_nth_opt n xs = Some (vector_nth xs n)
using assms by (simp add: vector_nth_opt_def)

(* Converse of vector_nth_opt_spec: if vector_nth_opt returns Some v, then the
   index is in range and vector_nth yields v. *)
lemma vector_nth_opt_spec2:
    fixes xs :: ('a, 'l::{len}) vector
  assumes vector_nth_opt n xs = Some v
    shows vector_nth xs n = v and n < vector_len xs
using assms by (auto simp add: vector_nth_opt_def split: if_splits)

(* Elimination-rule form of vector_nth_opt_spec2: from vector_nth_opt n xs =
   Some v, conclude R by assuming vector_nth xs n = v and n < vector_len xs. *)
lemma vector_nth_optE:
    fixes xs :: ('a, 'l::{len}) vector
  assumes vector_nth_opt n xs = Some v
      and vector_nth xs n = v  n < vector_len xs  R
    shows R
using assms vector_nth_opt_spec2 by metis

(* Introduction rule: an in-range index whose element is v gives
   vector_nth_opt n xs = Some v. *)
lemma vector_nth_optI [intro]:
    fixes xs :: ('a, 'l::{len}) vector
  assumes n < vector_len xs
      and vector_nth xs n = v
    shows vector_nth_opt n xs = Some v
using assms by (auto simp add: vector_nth_opt_spec)

(* Modify the element at index i by applying f to it. For an out-of-range i
   the vector is returned unchanged. Argument order: index, function, vector. *)
definition vector_over_nth :: nat  ('a  'a)  ('a, 'l::{len}) vector  ('a, 'l) vector where
  vector_over_nth i f v  if i < vector_len v then vector_update v i (f (vector_nth v i)) else v

(* Pointwise unfolding equation for vector_over_nth; it restates the
   definition as a lemma. *)
lemma vector_over_nth_list_update:
  shows vector_over_nth n f xs = (if n < vector_len xs then vector_update xs n (f (vector_nth xs n)) else xs)
by (simp add: vector_over_nth_def)

(* Point-free variant of vector_over_nth_list_update, stated as an equation
   between functions. *)
―‹NOTE: This lemma is not used at present, but seems worth keeping.›
lemma vector_over_nth_list_update':
  shows vector_over_nth = (λn f xs. if n < vector_len xs then vector_update xs n (f (vector_nth xs n)) else xs)
using vector_over_nth_list_update by blast

(* Lens-style laws for vector_over_nth at index n:
   1. Reading index n afterwards gives map_option f of the old read.
   2. If f leaves the entry at n unchanged, the whole vector is unchanged.
   3. Two consecutive modifications at n compose into one. *)
lemma vector_over_nth_is_valid:
  shows vector_nth_opt n (vector_over_nth n f l) = map_option f (vector_nth_opt n l)
    and map_option f (vector_nth_opt n l) = vector_nth_opt n l  vector_over_nth n f l = l
    and vector_over_nth n f (vector_over_nth n g l) = vector_over_nth n (λx. f (g x)) l
by (auto simp add: vector_over_nth_def vector_nth_opt_def intro: vector_extI)

(* If x is stored at some index below m, and m is within the vector's
   length, then x occurs among the first m elements of the list view.
   This is the vector-level counterpart of lookup1_take_list_fast. *)
lemma set_take_vector_to_list:
    fixes vs :: ('a, 'b::{len}) vector
  assumes n. n < m  vector_nth vs n = x
      and m  vector_len vs
    shows x  set (take m (vector_to_list vs))
using assms by transfer (auto simp add: is_vector_def intro!: lookup1_take_list_fast)

(* Vector-level elimination form of set_image_take_nth. The first assumption
   claims that f (vector_nth ls i) lies outside the image of the first n list
   elements; with i in range and below n this is contradictory, so any R
   follows. *)
lemma set_image_vector_take_vector_nth_elim':
  assumes f (vector_nth ls i)  f ` set (take n (vector_to_list ls))
      and i < vector_len ls
      and i < n
    shows R
using assms
  by (metis set_image_take_nth vector_len_list vector_nth_list)

(*<*)
end
(*>*)