Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 104 additions & 12 deletions src/plugins/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,62 @@ impl SqlPlugin {
}
}

/// Build a qualifier→real-table map for the FROM/JOIN clauses of a query.
/// Each real table maps to itself, and each alias (e.g. `u` in
/// `FROM users u`) maps to its real table, so later passes can resolve a
/// qualified reference like `u.id` to the `users` table.
fn extract_table_aliases(statement: &Statement) -> Vec<(String, String)> {
let mut aliases = Vec::new();
match statement {
Statement::Query(query) => {
if let SetExpr::Select(select) = query.body.as_ref() {
for twj in &select.from {
Self::collect_aliases_from_join(twj, &mut aliases);
}
}
}
Statement::Delete(delete) => match &delete.from {
sqlparser::ast::FromTable::WithFromKeyword(twjs)
| sqlparser::ast::FromTable::WithoutKeyword(twjs) => {
for twj in twjs {
Self::collect_aliases_from_join(twj, &mut aliases);
}
}
},
Statement::Update { table, .. } => Self::collect_aliases_from_join(table, &mut aliases),
_ => {}
}
aliases
}

fn collect_aliases_from_join(twj: &TableWithJoins, aliases: &mut Vec<(String, String)>) {
Self::collect_alias_from_factor(&twj.relation, aliases);
for join in &twj.joins {
Self::collect_alias_from_factor(&join.relation, aliases);
}
}

fn collect_alias_from_factor(factor: &TableFactor, aliases: &mut Vec<(String, String)>) {
if let TableFactor::Table { name, alias, .. } = factor {
let real = name.to_string().to_lowercase();
aliases.push((real.clone(), real.clone()));
if let Some(alias) = alias {
aliases.push((alias.name.value.to_lowercase(), real));
}
}
}

/// Resolve a table qualifier (a real table name or an alias) to its real
/// table name. An unknown qualifier falls back to itself, so a later schema
/// lookup still flags it rather than silently passing.
fn resolve_qualifier(aliases: &[(String, String)], qualifier: &str) -> String {
aliases
.iter()
.find(|(q, _)| q == qualifier)
.map(|(_, t)| t.clone())
.unwrap_or_else(|| qualifier.to_string())
}

/// Extract all column references from a statement.
fn extract_column_refs(statement: &Statement) -> Vec<(Option<String>, String)> {
let mut cols = Vec::new();
Expand Down Expand Up @@ -172,12 +228,13 @@ impl SqlPlugin {
right: &Expr,
schema: &Schema,
tables_in_query: &[String],
aliases: &[(String, String)],
) -> Vec<TypeIssue> {
let mut issues = Vec::new();

// Get types of left and right if we can resolve them
let left_type = Self::infer_expr_type(left, schema, tables_in_query);
let right_type = Self::infer_expr_type(right, schema, tables_in_query);
let left_type = Self::infer_expr_type(left, schema, tables_in_query, aliases);
let right_type = Self::infer_expr_type(right, schema, tables_in_query, aliases);

if let (Some(lt), Some(rt)) = (&left_type, &right_type) {
let lt_cat = type_category(lt);
Expand Down Expand Up @@ -225,7 +282,12 @@ impl SqlPlugin {
}

/// Attempt to infer the SQL type of an expression given the schema.
fn infer_expr_type(expr: &Expr, schema: &Schema, tables_in_query: &[String]) -> Option<String> {
fn infer_expr_type(
expr: &Expr,
schema: &Schema,
tables_in_query: &[String],
aliases: &[(String, String)],
) -> Option<String> {
match expr {
Expr::Identifier(ident) => {
let col_name = ident.value.to_lowercase();
Expand All @@ -240,7 +302,7 @@ impl SqlPlugin {
None
}
Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
let table_name = parts[0].value.to_lowercase();
let table_name = Self::resolve_qualifier(aliases, &parts[0].value.to_lowercase());
let col_name = parts[1].value.to_lowercase();
if let Some(table) = schema.tables.iter().find(|t| t.name == table_name)
&& let Some(col) = table.columns.iter().find(|c| c.name == col_name)
Expand Down Expand Up @@ -323,6 +385,7 @@ impl QueryLanguagePlugin for SqlPlugin {
for stmt in &statements {
// Check table references
let table_refs = Self::extract_table_refs(stmt);
let aliases = Self::extract_table_aliases(stmt);
for table_name in &table_refs {
if !schema.tables.iter().any(|t| t.name == *table_name) {
issues.push(SchemaIssue {
Expand All @@ -334,10 +397,12 @@ impl QueryLanguagePlugin for SqlPlugin {
// Check column references
let col_refs = Self::extract_column_refs(stmt);
for (table_qualifier, col_name) in &col_refs {
let tables_to_check: Vec<&str> = if let Some(tq) = table_qualifier {
vec![tq.as_str()]
// Resolve an alias qualifier (`u`) to its real table (`users`);
// an unqualified column is checked against every table in scope.
let tables_to_check: Vec<String> = if let Some(tq) = table_qualifier {
vec![Self::resolve_qualifier(&aliases, tq)]
} else {
table_refs.iter().map(|s| s.as_str()).collect()
table_refs.clone()
};

let found = tables_to_check.iter().any(|tn| {
Expand Down Expand Up @@ -378,13 +443,14 @@ impl QueryLanguagePlugin for SqlPlugin {

for stmt in &statements {
let table_refs = Self::extract_table_refs(stmt);
let aliases = Self::extract_table_aliases(stmt);

// Check WHERE clause binary operations for type compatibility
if let Statement::Query(query) = stmt
&& let SetExpr::Select(select) = query.body.as_ref()
&& let Some(ref selection) = select.selection
{
Self::check_expr_types(selection, schema, &table_refs, &mut issues);
Self::check_expr_types(selection, schema, &table_refs, &aliases, &mut issues);
}
}

Expand All @@ -398,6 +464,7 @@ impl QueryLanguagePlugin for SqlPlugin {

for stmt in &statements {
let table_refs = Self::extract_table_refs(stmt);
let aliases = Self::extract_table_aliases(stmt);

// Check if SELECT includes nullable columns without COALESCE or IS NULL handling
if let Statement::Query(q) = stmt
Expand Down Expand Up @@ -428,6 +495,29 @@ impl QueryLanguagePlugin for SqlPlugin {
}
}
}
// Alias-qualified projection (e.g. `u.email` in `FROM users u`):
// resolve the qualifier so nullability is still checked.
SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts))
| SelectItem::ExprWithAlias {
expr: Expr::CompoundIdentifier(parts),
..
} if parts.len() == 2 => {
let table_name =
Self::resolve_qualifier(&aliases, &parts[0].value.to_lowercase());
let col_name = parts[1].value.to_lowercase();
if let Some(table) = schema.tables.iter().find(|t| t.name == table_name)
&& let Some(col) = table.columns.iter().find(|c| c.name == col_name)
&& col.nullable
{
issues.push(NullIssue {
message: format!(
"Nullable column '{}' selected without COALESCE or null handling",
col_name
),
column: col_name.clone(),
});
}
}
_ => {}
}
}
Expand All @@ -444,16 +534,18 @@ impl SqlPlugin {
expr: &Expr,
schema: &Schema,
tables: &[String],
aliases: &[(String, String)],
issues: &mut Vec<TypeIssue>,
) {
match expr {
Expr::BinaryOp { left, op, right } => {
let new_issues = Self::check_binary_op_types(op, left, right, schema, tables);
let new_issues =
Self::check_binary_op_types(op, left, right, schema, tables, aliases);
issues.extend(new_issues);
Self::check_expr_types(left, schema, tables, issues);
Self::check_expr_types(right, schema, tables, issues);
Self::check_expr_types(left, schema, tables, aliases, issues);
Self::check_expr_types(right, schema, tables, aliases, issues);
}
Expr::Nested(inner) => Self::check_expr_types(inner, schema, tables, issues),
Expr::Nested(inner) => Self::check_expr_types(inner, schema, tables, aliases, issues),
_ => {}
}
}
Expand Down
57 changes: 54 additions & 3 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,60 @@ fn l2_valid_multi_table_join() {
&schema,
)
.unwrap();
// Qualified columns use alias (u, p) which don't match schema table names directly.
// This is expected — aliases need separate resolution logic.
let _ = issues;
// Aliases `u`/`p` resolve to `users`/`posts`; every column exists, so a
// valid aliased join must produce no schema-binding issues.
assert!(
issues.is_empty(),
"Aliased join over existing columns must pass L2. Got: {:?}",
issues
);
}

#[test]
fn l2_alias_resolves_unknown_column() {
// The alias resolves to a real table, so a genuinely missing column on the
// aliased table is still caught (no false negative from alias resolution).
let plugin = get_plugin("sql").unwrap();
let schema = test_schema();
let issues = plugin
.schema_check("SELECT u.nonexistent FROM users u", &schema)
.unwrap();
assert!(
issues.iter().any(|i| i.message.contains("nonexistent")),
"Missing column on an aliased table must be flagged. Got: {:?}",
issues
);
}

#[test]
fn l3_type_check_resolves_alias() {
// `u.name` (text) compared to a numeric literal must be caught even though
// the column is referenced through the alias `u`.
let plugin = get_plugin("sql").unwrap();
let schema = test_schema();
let issues = plugin
.type_check("SELECT u.id FROM users u WHERE u.name = 42", &schema)
.unwrap();
assert!(
!issues.is_empty(),
"Type mismatch on an alias-qualified column must be flagged"
);
}

#[test]
fn l4_null_check_resolves_alias() {
// `u.email` is nullable; selecting it through the alias `u` without null
// handling must still raise a null-safety issue.
let plugin = get_plugin("sql").unwrap();
let schema = test_schema();
let issues = plugin
.null_check("SELECT u.email FROM users u", &schema)
.unwrap();
assert!(
issues.iter().any(|i| i.column == "email"),
"Nullable alias-qualified column must be flagged at L4. Got: {:?}",
issues
);
}

#[test]
Expand Down
Loading