diff --git a/src/plugins/sql.rs b/src/plugins/sql.rs
index c44aea0..2f2eea4 100644
--- a/src/plugins/sql.rs
+++ b/src/plugins/sql.rs
@@ -103,6 +103,73 @@ impl SqlPlugin {
         }
     }
 
+    /// Build a qualifier→real-table map for the tables named in the FROM/JOIN
+    /// clauses of a SELECT or DELETE, or the target table of an UPDATE. Each
+    /// alias (e.g. `u` in `FROM users u`) maps to its real table, so later
+    /// passes can resolve a qualified reference like `u.id` to the `users`
+    /// table. Un-aliased tables get no entry: `resolve_qualifier` falls back to
+    /// the qualifier itself. Queries whose body is not a plain `SELECT`, and
+    /// all other statements, yield an empty map.
+    fn extract_table_aliases(statement: &Statement) -> Vec<(String, String)> {
+        let mut aliases = Vec::new();
+        match statement {
+            Statement::Query(query) => {
+                if let SetExpr::Select(select) = query.body.as_ref() {
+                    for twj in &select.from {
+                        Self::collect_aliases_from_join(twj, &mut aliases);
+                    }
+                }
+            }
+            Statement::Delete(delete) => match &delete.from {
+                sqlparser::ast::FromTable::WithFromKeyword(twjs)
+                | sqlparser::ast::FromTable::WithoutKeyword(twjs) => {
+                    for twj in twjs {
+                        Self::collect_aliases_from_join(twj, &mut aliases);
+                    }
+                }
+            },
+            Statement::Update { table, .. } => Self::collect_aliases_from_join(table, &mut aliases),
+            _ => {}
+        }
+        aliases
+    }
+
+    /// Collect aliases from the base relation of a FROM item and from every
+    /// relation joined onto it.
+    fn collect_aliases_from_join(twj: &TableWithJoins, aliases: &mut Vec<(String, String)>) {
+        Self::collect_alias_from_factor(&twj.relation, aliases);
+        for join in &twj.joins {
+            Self::collect_alias_from_factor(&join.relation, aliases);
+        }
+    }
+
+    /// Record `alias → real table` (both lowercased) for a plain table factor
+    /// that carries an alias; every other factor is ignored.
+    ///
+    /// Un-aliased tables are deliberately not stored as identity entries. They
+    /// would be redundant, because `resolve_qualifier` falls back to the
+    /// qualifier itself, and harmful, because lookup is first-match: in
+    /// `FROM b c JOIN a b` an identity entry for `b` would precede the alias
+    /// `b → a`, so `b.x` would wrongly resolve to table `b` instead of `a`.
+    fn collect_alias_from_factor(factor: &TableFactor, aliases: &mut Vec<(String, String)>) {
+        if let TableFactor::Table { name, alias: Some(alias), .. } = factor {
+            let real = name.to_string().to_lowercase();
+            aliases.push((alias.name.value.to_lowercase(), real));
+        }
+    }
+
+    /// Resolve a table qualifier (a real table name or an alias) to its real
+    /// table name. The first matching alias wins; a qualifier with no alias
+    /// entry falls back to itself, so a later schema lookup still flags an
+    /// unknown table rather than silently passing.
+    fn resolve_qualifier(aliases: &[(String, String)], qualifier: &str) -> String {
+        aliases
+            .iter()
+            .find(|(q, _)| q == qualifier)
+            .map(|(_, t)| t.clone())
+            .unwrap_or_else(|| qualifier.to_string())
+    }
+
     /// Extract all column references from a statement.
     fn extract_column_refs(statement: &Statement) -> Vec<(Option, String)> {
         let mut cols = Vec::new();
@@ -172,12 +239,13 @@ impl SqlPlugin {
         right: &Expr,
         schema: &Schema,
         tables_in_query: &[String],
+        aliases: &[(String, String)],
     ) -> Vec {
         let mut issues = Vec::new();
 
         // Get types of left and right if we can resolve them
-        let left_type = Self::infer_expr_type(left, schema, tables_in_query);
-        let right_type = Self::infer_expr_type(right, schema, tables_in_query);
+        let left_type = Self::infer_expr_type(left, schema, tables_in_query, aliases);
+        let right_type = Self::infer_expr_type(right, schema, tables_in_query, aliases);
 
         if let (Some(lt), Some(rt)) = (&left_type, &right_type) {
             let lt_cat = type_category(lt);
@@ -225,7 +293,12 @@ impl SqlPlugin {
     }
 
     /// Attempt to infer the SQL type of an expression given the schema.
-    fn infer_expr_type(expr: &Expr, schema: &Schema, tables_in_query: &[String]) -> Option {
+    fn infer_expr_type(
+        expr: &Expr,
+        schema: &Schema,
+        tables_in_query: &[String],
+        aliases: &[(String, String)],
+    ) -> Option {
         match expr {
             Expr::Identifier(ident) => {
                 let col_name = ident.value.to_lowercase();
@@ -240,7 +313,7 @@ impl SqlPlugin {
                 None
             }
             Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
-                let table_name = parts[0].value.to_lowercase();
+                let table_name = Self::resolve_qualifier(aliases, &parts[0].value.to_lowercase());
                 let col_name = parts[1].value.to_lowercase();
                 if let Some(table) = schema.tables.iter().find(|t| t.name == table_name)
                     && let Some(col) = table.columns.iter().find(|c| c.name == col_name)
@@ -323,6 +396,7 @@ impl QueryLanguagePlugin for SqlPlugin {
         for stmt in &statements {
             // Check table references
             let table_refs = Self::extract_table_refs(stmt);
+            let aliases = Self::extract_table_aliases(stmt);
             for table_name in &table_refs {
                 if !schema.tables.iter().any(|t| t.name == *table_name) {
                     issues.push(SchemaIssue {
@@ -334,10 +408,15 @@ impl QueryLanguagePlugin for SqlPlugin {
             // Check column references
             let col_refs = Self::extract_column_refs(stmt);
             for (table_qualifier, col_name) in &col_refs {
-                let tables_to_check: Vec<&str> = if let Some(tq) = table_qualifier {
-                    vec![tq.as_str()]
+                // Resolve an alias qualifier (`u`) to its real table (`users`);
+                // an unqualified column is checked against every table in scope.
+                // `resolved` owns the resolved name so the list can stay borrowed.
+                let resolved;
+                let tables_to_check: Vec<&str> = if let Some(tq) = table_qualifier {
+                    resolved = Self::resolve_qualifier(&aliases, tq);
+                    vec![resolved.as_str()]
                 } else {
                     table_refs.iter().map(|s| s.as_str()).collect()
                 };
 
                 let found = tables_to_check.iter().any(|tn| {
@@ -378,13 +457,14 @@ impl QueryLanguagePlugin for SqlPlugin {
 
         for stmt in &statements {
             let table_refs = Self::extract_table_refs(stmt);
+            let aliases = Self::extract_table_aliases(stmt);
 
             // Check WHERE clause binary operations for type compatibility
             if let Statement::Query(query) = stmt
                 && let SetExpr::Select(select) = query.body.as_ref()
                 && let Some(ref selection) = select.selection
             {
-                Self::check_expr_types(selection, schema, &table_refs, &mut issues);
+                Self::check_expr_types(selection, schema, &table_refs, &aliases, &mut issues);
             }
         }
 
@@ -398,6 +478,7 @@ impl QueryLanguagePlugin for SqlPlugin {
 
         for stmt in &statements {
             let table_refs = Self::extract_table_refs(stmt);
+            let aliases = Self::extract_table_aliases(stmt);
 
             // Check if SELECT includes nullable columns without COALESCE or IS NULL handling
             if let Statement::Query(q) = stmt
@@ -428,6 +509,29 @@ impl QueryLanguagePlugin for SqlPlugin {
                                 }
                             }
                         }
+                        // Alias-qualified projection (e.g. `u.email` in `FROM users u`):
+                        // resolve the qualifier so nullability is still checked.
+                        SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts))
+                        | SelectItem::ExprWithAlias {
+                            expr: Expr::CompoundIdentifier(parts),
+                            ..
+                        } if parts.len() == 2 => {
+                            let table_name =
+                                Self::resolve_qualifier(&aliases, &parts[0].value.to_lowercase());
+                            let col_name = parts[1].value.to_lowercase();
+                            if let Some(table) = schema.tables.iter().find(|t| t.name == table_name)
+                                && let Some(col) = table.columns.iter().find(|c| c.name == col_name)
+                                && col.nullable
+                            {
+                                issues.push(NullIssue {
+                                    message: format!(
+                                        "Nullable column '{}' selected without COALESCE or null handling",
+                                        col_name
+                                    ),
+                                    column: col_name,
+                                });
+                            }
+                        }
                         _ => {}
                     }
                 }
@@ -444,16 +548,18 @@ impl SqlPlugin {
         expr: &Expr,
         schema: &Schema,
         tables: &[String],
+        aliases: &[(String, String)],
         issues: &mut Vec,
     ) {
         match expr {
             Expr::BinaryOp { left, op, right } => {
-                let new_issues = Self::check_binary_op_types(op, left, right, schema, tables);
+                let new_issues =
+                    Self::check_binary_op_types(op, left, right, schema, tables, aliases);
                 issues.extend(new_issues);
-                Self::check_expr_types(left, schema, tables, issues);
-                Self::check_expr_types(right, schema, tables, issues);
+                Self::check_expr_types(left, schema, tables, aliases, issues);
+                Self::check_expr_types(right, schema, tables, aliases, issues);
             }
-            Expr::Nested(inner) => Self::check_expr_types(inner, schema, tables, issues),
+            Expr::Nested(inner) => Self::check_expr_types(inner, schema, tables, aliases, issues),
             _ => {}
         }
     }
diff --git a/tests/integration_test.rs b/tests/integration_test.rs
index e6c8ad1..cfc9c30 100644
--- a/tests/integration_test.rs
+++ b/tests/integration_test.rs
@@ -302,9 +302,77 @@ fn l2_valid_multi_table_join() {
         &schema,
     )
     .unwrap();
-    // Qualified columns use alias (u, p) which don't match schema table names directly.
-    // This is expected — aliases need separate resolution logic.
-    let _ = issues;
+    // Aliases `u`/`p` resolve to `users`/`posts`; every column exists, so a
+    // valid aliased join must produce no schema-binding issues.
+    assert!(
+        issues.is_empty(),
+        "Aliased join over existing columns must pass L2. Got: {:?}",
+        issues
+    );
+}
+
+#[test]
+fn l2_alias_resolves_unknown_column() {
+    // The alias resolves to a real table, so a genuinely missing column on the
+    // aliased table is still caught (no false negative from alias resolution).
+    let plugin = get_plugin("sql").unwrap();
+    let schema = test_schema();
+    let issues = plugin
+        .schema_check("SELECT u.nonexistent FROM users u", &schema)
+        .unwrap();
+    assert!(
+        issues.iter().any(|i| i.message.contains("nonexistent")),
+        "Missing column on an aliased table must be flagged. Got: {:?}",
+        issues
+    );
+}
+
+#[test]
+fn l2_alias_shadows_other_table_name() {
+    // `posts` is the alias of `users` here, while the real `posts` table is
+    // exposed as `p`. `posts.email` must therefore bind to `users.email`, not
+    // to a column of the real `posts` table.
+    let plugin = get_plugin("sql").unwrap();
+    let schema = test_schema();
+    let issues = plugin
+        .schema_check("SELECT posts.email FROM posts p, users posts", &schema)
+        .unwrap();
+    assert!(
+        issues.is_empty(),
+        "An alias must take precedence over a same-named table. Got: {:?}",
+        issues
+    );
+}
+
+#[test]
+fn l3_type_check_resolves_alias() {
+    // `u.name` (text) compared to a numeric literal must be caught even though
+    // the column is referenced through the alias `u`.
+    let plugin = get_plugin("sql").unwrap();
+    let schema = test_schema();
+    let issues = plugin
+        .type_check("SELECT u.id FROM users u WHERE u.name = 42", &schema)
+        .unwrap();
+    assert!(
+        !issues.is_empty(),
+        "Type mismatch on an alias-qualified column must be flagged"
+    );
+}
+
+#[test]
+fn l4_null_check_resolves_alias() {
+    // `u.email` is nullable; selecting it through the alias `u` without null
+    // handling must still raise a null-safety issue.
+    let plugin = get_plugin("sql").unwrap();
+    let schema = test_schema();
+    let issues = plugin
+        .null_check("SELECT u.email FROM users u", &schema)
+        .unwrap();
+    assert!(
+        issues.iter().any(|i| i.column == "email"),
+        "Nullable alias-qualified column must be flagged at L4. Got: {:?}",
+        issues
+    );
 }
 
 #[test]