(* Theory Misc.Array *)

(* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
   SPDX-License-Identifier: MIT *)

(*<*)
theory Array
  imports "HOL-Library.Word" Data_Structures.BraunTreeAdditional ListAdditional
begin
(*>*)

section‹A fixed-length array type›

subsection‹The type itself›

(* ‹is_array t n› holds iff ‹t› is a Braun tree containing exactly ‹n› elements.
   This is the invariant carved out by the ‹array› typedef. *)
definition is_array :: ‹'a tree ⇒ nat ⇒ bool› where
  ‹is_array t n ≡ braun t ∧ size t = n›

(* Fixed-length arrays: Braun trees whose size is the length ‹LENGTH('b)› attached to the
   type parameter ‹'b› by the ‹len› class. The ‹(overloaded)› flag is required because the
   carrier set mentions the overloaded constant ‹LENGTH›. Non-emptiness is witnessed by
   ‹braun_of undefined LENGTH('b)›. *)
typedef (overloaded) ('a, 'b::{len}) array = ‹{ t::'a tree . is_array t LENGTH('b) }›
proof -
  have ‹braun_of undefined LENGTH('b) ∈ {t. is_array t LENGTH('b)}›
    by (metis braun_braun_of is_array_def length_replicate list_braun_of mem_Collect_eq size_list)
  from this show ?thesis
    by blast
qed

(*<*)
(* Register the ‹array› typedef with the Lifting package, so that the operations below can
   be introduced with ‹lift_definition› and their lemmas proved via ‹transfer›. *)
setup_lifting array.type_definition_array

text‹This instantiation tells the code generator how to decide equality at type ‹array›:
it simply uses the equality of the underlying trees.›
instantiation array :: (equal, len) equal
begin

(* Equality on arrays is equality of their representing trees. *)
lift_definition equal_array :: ‹('a, 'b::len) array ⇒ ('a, 'b) array ⇒ bool› is ‹(=)› .

(* The class law follows from injectivity of the representation function. *)
instance
  by standard (simp add: equal_array_def Rep_array_inject)

end
(*>*)

subsection ‹Lifted operations on the type›

(* The number of elements of an array. It depends only on the type ‹'l›, never on ‹xs›. *)
abbreviation array_len :: ‹('a,'l::{len}) array ⇒ nat› where
  ‹array_len xs ≡ LENGTH('l)›

(* Apply ‹f› to every element. Lifted from ‹tree_map›; the invariant is preserved
   by ‹tree_map_braun›. *)
lift_definition array_map :: ‹('a ⇒ 'c) ⇒ ('a, 'b::{len}) array ⇒ ('c, 'b::{len}) array›
  is tree_map by (simp add: is_array_def tree_map_braun)

(* 0-based element access. The underlying ‹lookup1› is offset by one, hence ‹i+1›.
   For ‹i ≥ array_len› the result is not characterised by the lemmas in this theory;
   they all assume a valid index. *)
lift_definition array_nth :: ‹('a, 'b::{len}) array ⇒ nat ⇒ 'a›
  is ‹λt i. lookup1 t (i+1)› .

(* 0-based functional update of index ‹i› with value ‹v›.
   An out-of-range index (‹i ≥ LENGTH('b)›) leaves the array unchanged. *)
lift_definition array_update :: ‹('a, 'b::{len}) array ⇒ nat ⇒ 'a ⇒ ('a, 'b::{len}) array›
  is ‹λt i v. if i < LENGTH('b) then update1 (i+1) v t else t›
  by (auto simp add: is_array_def braun_update1 size_update1)

(* The array in which every entry is ‹v› (see ‹array_constant_nth›). *)
lift_definition array_constant :: ‹'a ⇒ ('a, 'l::{len}) array›
  is ‹λv. braun_of v LENGTH('l)›
proof -
  fix a :: 'a
  have ‹size (braun_of a LENGTH('l)) = LENGTH('l)›
    using length_replicate size_list list_braun_of by metis
  from this show ‹is_array (braun_of a LENGTH('l)) LENGTH('l)›
    by (clarsimp simp add: is_array_def braun_braun_of)
qed

(* Convert a list into an array. If the tree built from the list does not have exactly
   ‹LENGTH('b)› elements, the result falls back to ‹braun_of undefined LENGTH('b)›, an
   unspecified array of the right size. The lemmas below only characterise the case
   ‹length ls = LENGTH('b)›. *)
lift_definition array_of_list :: ‹'a list ⇒ ('a, 'b::len) Array.array›
  is ‹λls. let a = brauns1 ls;
               n = size_fast a
            in if n = LENGTH('b) then a else braun_of undefined LENGTH('b)›
proof -
  fix list :: ‹'a list›
  {
    assume ‹size (brauns1 list) ≠ LENGTH('b)›
    from this have ‹size (braun_of undefined LENGTH('b)) = LENGTH('b)›
      using length_replicate size_list list_braun_of by metis
  }
  from this show ‹?thesis list›
    by (clarsimp simp add: is_array_def brauns1_correct braun_braun_of size_fast Let_def)
qed

(* The elements of an array as a list, in index order (see ‹array_to_list_nth›). *)
lift_definition array_to_list :: ‹('a, 'b::len) array ⇒ 'a list›
  is list_fast .

(* ‹array_splice n xs ys› overwrites positions ‹0..<n› of ‹ys› with the corresponding
   entries of ‹xs› (see ‹array_splice_nth›). Positions outside ‹ys› are skipped because
   ‹array_update› ignores them; if ‹n› exceeds ‹array_len xs›, the copied values come
   from uncharacterised ‹array_nth› reads. *)
definition array_splice :: ‹nat ⇒ ('a, 'l0::{len}) array ⇒ ('a, 'l1::{len}) array ⇒ ('a, 'l1) array› where
  ‹array_splice n xs ys ≡ foldl (λa i. array_update a i (array_nth xs i)) ys [0..<n]›

(* Change the length from ‹LENGTH('l0)› to ‹LENGTH('l1)›.
   - Growing: the new trailing positions are filled with ‹default›.
   - Shrinking or equal length: the first ‹LENGTH('l1)› elements are kept.
   See ‹resize_array_nth›. *)
lift_definition array_resize :: ‹('a, 'l0::{len}) array ⇒ 'a ⇒ ('a, 'l1::len) array›
  is ‹λxs default. if LENGTH('l0) < LENGTH('l1) then
             adds (replicate (LENGTH('l1) - LENGTH('l0)) default) (LENGTH('l0)) xs
           else
             braun_take xs LENGTH('l1)›
by (auto simp add: is_array_def size_braun_adds braun_braun_take size_braun_take)

(* The list view agrees with ‹array_nth› at every valid index. *)
lemma array_to_list_nth [simp]:
  assumes ‹i < array_len xs›
    shows ‹array_to_list xs ! i = array_nth xs i›
using assms by (transfer, clarsimp simp add: is_array_def list_fast_correct nth_list_lookup1)

(* The list view always has exactly ‹array_len xs› elements. *)
lemma array_to_list_length [simp]:
  shows ‹length (array_to_list xs) = array_len xs›
by (transfer, clarsimp simp add: is_array_def list_fast_correct size_list)

(* ‹array_update› corresponds to ‹list_update› on the list view. There is no bound on
   ‹i›: the statement also covers out-of-range indices. *)
lemma array_to_list_update [simp]:
  shows ‹array_to_list (array_update l i v) = list_update (array_to_list l) i v›
  by (transfer, simp add: is_array_def list_fast_correct braun_update1 list_update1 size_list)

(* Element-wise characterisation of ‹array_map› at valid indices. *)
lemma array_map_nth [simp]:
  assumes ‹i < array_len xs›
    shows ‹array_nth (array_map f xs) i = f (array_nth xs i)›
using assms by transfer (auto simp add: is_array_def tree_map_lookup1)

(* Reading an array built from a list of the right length returns the list's entries. *)
lemma list_to_array_nth [simp]:
  assumes ‹length ls = LENGTH('l::{len})›
      and ‹i < LENGTH('l)›
    shows ‹array_nth (array_of_list ls :: ('a, 'l) array) i = ls ! i›
using assms proof transfer
     fix ls :: ‹'a list›
     and i
  assume ‹length ls = LENGTH('l)›
     and ‹i < LENGTH('l)›
  moreover {
    (* Main case: the tree built from ‹ls› is used directly. *)
    assume ‹size_fast (brauns1 ls) = LENGTH('l)›
    from this calculation have ‹lookup1 (brauns1 ls) (Suc i) = ls ! i›
      by (metis add.commute brauns1_correct nth_list_lookup1 plus_1_eq_Suc size_fast)
  } moreover {
    (* Fallback branch of ‹array_of_list›: closed by contradiction with the
       length assumption. *)
    assume ‹size_fast (brauns1 ls) ≠ LENGTH('l)›
    from this calculation have ‹lookup1 (braun_of undefined LENGTH('l)) (Suc i) = ls ! i›
      by (metis brauns1_correct size_fast size_list)
  }
  ultimately show ‹lookup1 (
    let a = brauns1 ls;
        n = size_fast a
     in if n = LENGTH('l) then a else braun_of undefined LENGTH('l)) (i + 1) = ls ! i›
    by (clarsimp simp add: Let_def)
qed

(* Reading after an update: the updated index returns the new value, all other valid
   indices are unchanged. *)
lemma array_nth_update [simp]:
  assumes ‹j < array_len l›
    shows ‹array_nth (array_update l i v) j = (if i = j then v else array_nth l j)›
using assms by transfer (auto simp add: is_array_def lookup1_update1)

(* Element-wise characterisation of ‹array_splice›: indices below ‹n› come from ‹xs›,
   the rest from ‹ys›. *)
lemma array_splice_nth [simp]:
  assumes ‹i < array_len ys›
    shows ‹array_nth (array_splice n xs ys) i = (if i < n then array_nth xs i else array_nth ys i)›
using assms by (induction n, auto simp add: array_splice_def)

(* Extensionality: two arrays of the same type are equal if they agree at every
   valid index. *)
lemma array_extI:
    fixes xs ys :: ‹('a, 'l::{len}) array›
  assumes ‹⋀i. i < LENGTH('l) ⟹ array_nth xs i = array_nth ys i›
    shows ‹xs = ys›
using assms by transfer (auto simp add: is_array_def intro: braun_tree_ext)

(* Arrays with equal list views are equal.
   NOTE(review): ‹semiring_norm(174)› selects a fact by its position in the
   ‹semiring_norm› collection. That position may shift between Isabelle releases;
   consider naming the intended fact directly. *)
lemma array_to_list_extI:
    fixes bs cs :: ‹('a, 'b::{len}) array›
  assumes ‹array_to_list bs = array_to_list cs›
    shows ‹bs = cs›
using assms by transfer (metis braun_tree_ext is_array_def list_fast_correct nth_list_lookup1
  semiring_norm(174))

(* Writing back the current value is a no-op. No bound on ‹j› is needed. *)
lemma array_update_nth [simp]:
  shows ‹array_update arr j (array_nth arr j) = arr›
by (simp add: array_extI)

(* Element-wise characterisation of ‹array_resize›: old entries are kept where they
   exist, and positions beyond the old length hold the default ‹d›. *)
lemma resize_array_nth:
    fixes xs :: ‹('a, 'l0::{len}) array›
      and ys :: ‹('a, 'l1::{len}) array›
  assumes ‹array_resize xs d = ys›
      and ‹i < LENGTH('l1)›
    shows ‹array_nth ys i = (if i < LENGTH('l0) then array_nth xs i else d)›
using assms by transfer (auto simp add: is_array_def braun_take_lookup1 lookup1_adds
  nth_list_lookup1 size_list nth_append split: if_splits)

(* Converting a list of the right length into an array and back yields the same list. *)
lemma list_to_array_to_list:
    fixes a :: ‹('a, 'b::{len}) array›
  assumes ‹a = array_of_list l›
      and ‹length l = LENGTH('b)›
    shows ‹array_to_list a = l›
using assms proof transfer
     fix a
     and l :: ‹'a list›
  assume ‹is_array a LENGTH('b)›
     and ‹a = (let a = brauns1 l; n = size_fast a in if n = LENGTH('b) then a else braun_of undefined LENGTH('b))›
     and ‹length l = LENGTH('b)›
  moreover {
      (* Save the outer facts before this block extends ‹calculation›. *)
      note calculation_thus_far = calculation
    assume ‹size_fast (brauns1 l) = LENGTH('b)›
    moreover from this calculation_thus_far have ‹a = brauns1 l› and ‹length l = LENGTH('b)›
      by (auto simp add: Let_def)
    moreover from calculation calculation_thus_far have ‹size_fast (brauns1 l) = LENGTH('b)›
      by (auto simp add: is_array_def)
    moreover from calculation have ‹list_fast (brauns1 l) = l›
      by (simp add: brauns1_correct list_fast_correct)
    ultimately have ‹list_fast a = l›
      by (clarsimp simp add: list_fast_correct)
  } moreover {
    (* Fallback branch of ‹array_of_list›: contradicts the length assumption. *)
    assume ‹size_fast (brauns1 l) ≠ LENGTH('b)›
    from this calculation have ‹list_fast a = l›
      by (clarsimp simp add: Let_def) (metis brauns1_correct size_fast size_list)
  }
  ultimately show ‹list_fast a = l›
    by blast
qed

(* One-step unfolding of ‹array_splice›: splicing ‹Suc n› elements means splicing ‹n›
   elements and then copying index ‹n›. *)
lemma array_splice_Suc:
  shows ‹array_splice (Suc n) xs ys = array_update (array_splice n xs ys) n (array_nth xs n)›
by (clarsimp intro!: array_extI simp add: array_splice_def)

(* Splicing zero elements leaves the target unchanged
   (stated both point-free and applied). *)
lemma array_splice_0 [simp]:
  shows ‹array_splice 0 xs = id›
    and ‹array_splice 0 xs ys = ys›
by (auto intro!: array_extI simp add: array_splice_def)

(* Every valid entry of ‹array_constant v› is ‹v›. *)
lemma array_constant_nth [simp]:
  assumes ‹i < LENGTH('l::{len})›
    shows ‹array_nth (array_constant v :: ('a, 'l :: {len}) array) i = v›
using assms proof transfer
     fix i
     and v :: 'a
  assume ‹i < LENGTH('l)›
  from this show ‹lookup1 (braun_of v LENGTH('l)) (i + 1) = v›
    using list_braun_of length_replicate nth_list_lookup1 nth_replicate by (metis Suc_eq_plus1
      braun_braun_of size_list)
qed

(* A second update at the same index overrides the first. *)
lemma array_update_overwrite [simp]:
  shows ‹array_update (array_update ls i v) i w = array_update ls i w›
by (intro array_extI) auto

(* Updates at distinct indices commute. *)
lemma array_update_swap:
  assumes ‹i ≠ j›
    shows ‹array_update (array_update ls i v) j w = array_update (array_update ls j w) i v›
using assms by (intro array_extI) auto

(* An update inside the spliced prefix (‹i < n›) can be pushed into the source ‹xs›. *)
lemma array_splice_update0:
    fixes xs :: ‹('a, 'l0::{len}) array›
      and ys :: ‹('a, 'l1::{len}) array›
  assumes ‹n ≤ LENGTH('l0)›
      and ‹n ≤ LENGTH('l1)›
      and ‹i < n›
    shows ‹array_update (array_splice n xs ys) i v = array_splice n (array_update xs i v) ys›
using assms by (auto intro: array_extI)

(* An update at or beyond the spliced prefix (‹n ≤ i›) can be pushed into the
   target ‹ys›. *)
lemma array_splice_update1:
    fixes xs :: ‹('a, 'l0::{len}) array›
      and ys :: ‹('a, 'l1::{len}) array›
  assumes ‹n ≤ LENGTH('l0)›
      and ‹n ≤ LENGTH('l1)›
      and ‹i < LENGTH('l1)›
      and ‹i ≥ n›
    shows ‹array_update (array_splice n xs ys) i v = array_splice n xs (array_update ys i v)›
using assms by (auto intro: array_extI)

(* Apply ‹f› to the element at index ‹i›. An out-of-range index leaves the array
   unchanged. *)
definition array_over_nth :: ‹nat ⇒ ('a ⇒ 'a) ⇒ ('a, 'l::{len}) array ⇒ ('a, 'l) array› where
  ‹array_over_nth i f v ≡ if i < LENGTH('l) then array_update v i (f (array_nth v i)) else v›

(* Unfolding of ‹array_over_nth›, with the bound phrased via ‹array_len›. *)
lemma array_over_nth_list_update:
  fixes xs :: ‹('a, 'l::{len}) array›
  shows ‹array_over_nth n f xs = (if n < array_len xs then array_update xs n (f (array_nth xs n)) else xs)›
by (auto simp add: array_over_nth_def)

(* Function-level form of ‹array_over_nth_list_update›. It can also rewrite occurrences
   of ‹array_over_nth› that are not fully applied.
   NOTE(review): the fixed ‹xs› is shadowed by the λ-bound ‹xs› and plays no role. *)
lemma array_over_nth_list_update':
  fixes xs :: ‹('a, 'l::{len}) array›
  shows ‹array_over_nth = (λn f xs. if n < array_len xs then array_update xs n (f (array_nth xs n)) else xs)›
by (auto simp add: array_over_nth_def intro!: ext)

(* Converting an array to a list and back is the identity. No length hypothesis is
   needed, because the list view always has length ‹LENGTH('b)›. *)
lemma array_of_list_to_array [simp]:
  shows ‹array_of_list (array_to_list a) = a›
proof transfer
     fix a :: ‹'a tree›
  assume ‹is_array a LENGTH('b::{len})›
  moreover {
      (* Save the outer facts before this block extends ‹calculation›. *)
      note calculation_thus_far = calculation
    assume ‹size_fast (brauns1 (list_fast a)) = LENGTH('b)›
    moreover from this calculation_thus_far have ‹braun a› and ‹size a = LENGTH('b)›
      by (auto simp add: is_array_def)
    moreover from calculation have ‹brauns1 (braun_list a) = a›
      by (metis Suc_eq_plus1 braun_tree_ext brauns1_correct list_fast_correct nth_list_lookup1 size_fast)
    ultimately have ‹brauns1 (list_fast a) = a›
      by (clarsimp simp add: list_fast_correct)
  } moreover {
    (* Fallback branch of ‹array_of_list›: contradicts ‹is_array a LENGTH('b)›. *)
    assume ‹size_fast (brauns1 (list_fast a)) ≠ LENGTH('b)›
    from this calculation have ‹braun_of undefined LENGTH('b) = a›
      by (metis brauns1_correct is_array_def list_fast_correct size_fast size_list)
  }
  ultimately show ‹(let a = brauns1 (list_fast a); n = size_fast a in if n = LENGTH('b) then a else braun_of undefined LENGTH('b)) = a›
    by (clarsimp simp add: Let_def)
qed

(* For lists of the right length, ‹array_of_list› turns ‹list_update› into
   ‹array_update›. *)
lemma array_of_list_list_update [simp]:
  assumes ‹LENGTH('b::{len}) = length xs›
    shows ‹(array_of_list (xs[i := e])::('a, 'b) array) = array_update (array_of_list xs) i e›
using assms by (metis array_of_list_to_array array_to_list_update list_to_array_to_list)

(*<*)
end
(*>*)