Skip to content
24 changes: 24 additions & 0 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,14 @@ impl<'tcx> Analyzer<'tcx> {
self.def_ids.clone()
}

pub fn mark_uses_seq_concat(&self) {
self.system.borrow_mut().uses_seq_concat = true;
}

pub fn mark_uses_seq_subseq(&self) {
self.system.borrow_mut().uses_seq_subseq = true;
}

pub fn add_clause(&mut self, clause: chc::Clause) {
self.system.borrow_mut().push_clause(clause);
}
Expand Down Expand Up @@ -454,6 +462,22 @@ impl<'tcx> Analyzer<'tcx> {
Some(formula_fn)
}

/// Companion of [`Self::formula_fn_with_args`] for `#[thrust::formula_fn]`
/// bodies whose body yields a model term (rather than a bool formula) — used
/// by `snapshot!{}`. Not cached.
pub fn term_fn_with_args(
&self,
local_def_id: LocalDefId,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Option<annot_fn::TermFn<'tcx>> {
// Reuse the registration set used for `formula_fn_with_args`: any
// `#[thrust::formula_fn]` may serve as either.
self.formula_fns.get(&local_def_id)?;
let translator = annot_fn::AnnotFnTranslator::new(self, local_def_id, generic_args)
.with_def_id_cache(self.def_ids());
Some(translator.to_term_fn())
}

pub fn def_ty_with_args(
&mut self,
def_id: DefId,
Expand Down
72 changes: 72 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,62 @@ pub fn array_model_store_path() -> [Symbol; 3] {
]
}

pub fn seq_model_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_model"),
]
}

pub fn seq_empty_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_empty"),
]
}

pub fn seq_singleton_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_singleton"),
]
}

pub fn seq_len_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_len"),
]
}

pub fn seq_push_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_push"),
]
}

pub fn seq_concat_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_concat"),
]
}

pub fn seq_subsequence_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("seq_subsequence"),
]
}

pub fn exists_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down Expand Up @@ -169,6 +225,22 @@ pub fn invariant_marker_path() -> [Symbol; 3] {
]
}

pub fn snapshot_marker_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("snapshot_marker"),
]
}

pub fn proof_assert_marker_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("proof_assert_marker"),
]
}

pub fn fn_param_wrapper_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down
134 changes: 133 additions & 1 deletion src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,42 @@ where
}
}

/// The term analogue of [`FormulaFn`]: a `#[thrust::formula_fn]` whose body
/// evaluates to a non-`bool` model term (used by `snapshot!{}`).
#[derive(Debug, Clone)]
pub struct TermFn<'tcx> {
params: IndexVec<rty::FunctionParamIdx, mir_ty::Ty<'tcx>>,
term: chc::Term<rty::FunctionParamIdx>,
}

impl<'a, D> Pretty<'a, D, termcolor::ColorSpec> for &TermFn<'_>
where
D: pretty::DocAllocator<'a, termcolor::ColorSpec>,
D::Doc: Clone,
{
fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> {
allocator
.intersperse(
self.params.iter_enumerated().map(|(idx, ty)| {
idx.pretty(allocator)
.append(": ")
.append(allocator.as_string(ty))
}),
", ",
)
.enclose("|", "|")
.group()
.append(self.term.pretty(allocator))
.group()
}
}

impl<'tcx> TermFn<'tcx> {
pub fn term(&self) -> &chc::Term<rty::FunctionParamIdx> {
&self.term
}
}

impl<'tcx> FormulaFn<'tcx> {
pub fn formula(&self) -> &chc::Formula<rty::FunctionParamIdx> {
&self.formula
Expand Down Expand Up @@ -314,6 +350,25 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
}
}

/// Same shape as [`Self::to_formula_fn`] but for `#[thrust::formula_fn]`
/// bodies whose value is a non-`bool` model term (i.e. `snapshot!{}`
/// closures). The body is interpreted via [`Self::to_term`] rather than
/// [`Self::to_formula`].
pub fn to_term_fn(&self) -> TermFn<'tcx> {
let term = self.to_term(self.body.value);
let params = self
.tcx
.fn_sig(self.local_def_id.to_def_id())
.instantiate(self.tcx, self.generic_args)
.skip_binder()
.inputs()
.to_vec();
TermFn {
params: IndexVec::from_raw(params),
term,
}
}

fn to_formula(&self, hir: &'tcx rustc_hir::Expr<'tcx>) -> chc::Formula<rty::FunctionParamIdx> {
self.to_formula_or_term(hir)
.into_formula()
Expand Down Expand Up @@ -623,9 +678,18 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
FormulaOrTerm::Term(term.tuple_proj(index))
}
ExprKind::Index(array, index, _) => {
let array_ty = self.expr_ty(array);
let array_term = self.to_term(array);
let index_term = self.to_term(index);
FormulaOrTerm::Term(array_term.select(index_term))
let is_seq = array_ty
.ty_adt_def()
.is_some_and(|adt| Some(adt.did()) == self.def_ids.seq_model());
let array_inner = if is_seq {
array_term.tuple_proj(0)
} else {
array_term
};
FormulaOrTerm::Term(array_inner.select(index_term))
}
ExprKind::MethodCall(method, receiver, args, _) => {
if let Some(def_id) = self.typeck.type_dependent_def_id(hir.hir_id) {
Expand All @@ -644,6 +708,56 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
let t = self.to_term(receiver);
return FormulaOrTerm::Term(t);
}
if Some(def_id) == self.def_ids.seq_len() {
assert!(args.is_empty(), "Seq::len does not take any arguments");
let t = self.to_term(receiver);
return FormulaOrTerm::Term(t.tuple_proj(1));
}
if Some(def_id) == self.def_ids.seq_push() {
assert_eq!(args.len(), 1, "Seq::push takes exactly 1 argument");
let t = self.to_term(receiver);
let v = self.to_term(&args[0]);
let arr = t.clone().tuple_proj(0);
let len = t.tuple_proj(1);
let new_arr = arr.store(len.clone(), v);
let new_len = len.add(chc::Term::int(1));
return FormulaOrTerm::Term(chc::Term::tuple(vec![new_arr, new_len]));
}
if Some(def_id) == self.def_ids.seq_concat() {
assert_eq!(args.len(), 1, "Seq::concat takes exactly 1 argument");
self.analyzer.mark_uses_seq_concat();
let t = self.to_term(receiver);
let other = self.to_term(&args[0]);
let a_arr = t.clone().tuple_proj(0);
let a_len = t.tuple_proj(1);
let b_arr = other.clone().tuple_proj(0);
let b_len = other.tuple_proj(1);
// The array half is the recursive SMT-defined
// function `seq_concat_arr_int(sa, sn, ta, tn)`;
// the length half is computed inline so length
// properties remain provable on any solver (the
// SMT obligation never has to mention the array).
let new_arr = chc::Term::App(
chc::Function::SEQ_CONCAT_ARR_INT,
vec![a_arr, a_len.clone(), b_arr, b_len.clone()],
);
let new_len = a_len.add(b_len);
return FormulaOrTerm::Term(chc::Term::tuple(vec![new_arr, new_len]));
}
if Some(def_id) == self.def_ids.seq_subsequence() {
assert_eq!(args.len(), 2, "Seq::subsequence takes exactly 2 arguments");
self.analyzer.mark_uses_seq_subseq();
let t = self.to_term(receiver);
let l = self.to_term(&args[0]);
let r = self.to_term(&args[1]);
let arr = t.tuple_proj(0);
let new_arr = chc::Term::App(
chc::Function::SEQ_SUBSEQ_ARR_INT,
vec![arr, l.clone(), r.clone()],
);
let new_len = r.sub(l);
return FormulaOrTerm::Term(chc::Term::tuple(vec![new_arr, new_len]));
}
}
unimplemented!("unsupported method call in formula: {:?}", method)
}
Expand Down Expand Up @@ -719,6 +833,24 @@ impl<'a, 'tcx> AnnotFnTranslator<'a, 'tcx> {
let t = self.to_term(&args[0]);
return FormulaOrTerm::Term(chc::Term::box_(t));
}
if Some(def_id) == self.def_ids.seq_empty() {
assert!(args.is_empty(), "Seq::empty does not take any arguments");
let arr = chc::Term::App(chc::Function::SEQ_DEFAULT_ARR_INT, vec![]);
return FormulaOrTerm::Term(chc::Term::tuple(vec![
arr,
chc::Term::int(0),
]));
}
if Some(def_id) == self.def_ids.seq_singleton() {
assert_eq!(args.len(), 1, "Seq::singleton takes exactly 1 argument");
let v = self.to_term(&args[0]);
let arr = chc::Term::App(chc::Function::SEQ_DEFAULT_ARR_INT, vec![]);
let new_arr = arr.store(chc::Term::int(0), v);
return FormulaOrTerm::Term(chc::Term::tuple(vec![
new_arr,
chc::Term::int(1),
]));
}
if let rustc_hir::def::DefKind::Ctor(ctor_of, _) = def_kind {
let terms = args.iter().map(|e| self.to_term(e)).collect();
match ctor_of {
Expand Down
Loading